2626
2727
2828def  get_code_optimization_context (
29-     function_to_optimize : FunctionToOptimize , project_root_path : Path , optim_token_limit : int  =  8000 , testgen_token_limit : int  =  8000 
29+     function_to_optimize : FunctionToOptimize ,
30+     project_root_path : Path ,
31+     optim_token_limit : int  =  8000 ,
32+     testgen_token_limit : int  =  8000 ,
3033) ->  CodeOptimizationContext :
3134    # Get FunctionSource representation of helpers of FTO 
32-     helpers_of_fto_dict , helpers_of_fto_list  =  get_function_sources_from_jedi ({function_to_optimize .file_path : {function_to_optimize .qualified_name }}, project_root_path )
35+     helpers_of_fto_dict , helpers_of_fto_list  =  get_function_sources_from_jedi (
36+         {function_to_optimize .file_path : {function_to_optimize .qualified_name }}, project_root_path 
37+     )
3338
3439    # Add function to optimize into helpers of FTO dict, as they'll be processed together 
3540    fto_as_function_source  =  get_function_to_optimize_as_function_source (function_to_optimize , project_root_path )
3641    helpers_of_fto_dict [function_to_optimize .file_path ].add (fto_as_function_source )
3742
3843    # Format data to search for helpers of helpers using get_function_sources_from_jedi 
3944    helpers_of_fto_qualified_names_dict  =  {
40-         file_path : {source .qualified_name  for  source  in  sources }
41-         for  file_path , sources  in  helpers_of_fto_dict .items ()
45+         file_path : {source .qualified_name  for  source  in  sources } for  file_path , sources  in  helpers_of_fto_dict .items ()
4246    }
4347
4448    # __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist) 
4549    # This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO 
4650    for  qualified_names  in  helpers_of_fto_qualified_names_dict .values ():
47-           qualified_names .update ({f"{ qn .rsplit ('.' , 1 )[0 ]}  .__init__"  for  qn  in  qualified_names  if  '.'  in  qn })
51+         qualified_names .update ({f"{ qn .rsplit ('.' , 1 )[0 ]}  .__init__"  for  qn  in  qualified_names  if  "."  in  qn })
4852
4953    # Get FunctionSource representation of helpers of helpers of FTO 
50-     helpers_of_helpers_dict , helpers_of_helpers_list  =  get_function_sources_from_jedi (helpers_of_fto_qualified_names_dict , project_root_path )
54+     helpers_of_helpers_dict , helpers_of_helpers_list  =  get_function_sources_from_jedi (
55+         helpers_of_fto_qualified_names_dict , project_root_path 
56+     )
5157
5258    # Extract code context for optimization 
53-     final_read_writable_code  =  extract_code_string_context_from_files (helpers_of_fto_dict ,{}, project_root_path , remove_docstrings = False , code_context_type = CodeContextType .READ_WRITABLE ).code 
59+     final_read_writable_code  =  extract_code_string_context_from_files (
60+         helpers_of_fto_dict ,
61+         {},
62+         project_root_path ,
63+         remove_docstrings = False ,
64+         code_context_type = CodeContextType .READ_WRITABLE ,
65+     ).code 
5466    read_only_code_markdown  =  extract_code_markdown_context_from_files (
5567        helpers_of_fto_dict ,
5668        helpers_of_helpers_dict ,
@@ -80,10 +92,7 @@ def get_code_optimization_context(
8092        logger .debug ("Code context has exceeded token limit, removing docstrings from read-only code" )
8193        # Extract read only code without docstrings 
8294        read_only_code_no_docstring_markdown  =  extract_code_markdown_context_from_files (
83-             helpers_of_fto_dict ,
84-             helpers_of_helpers_dict ,
85-             project_root_path ,
86-             remove_docstrings = True ,
95+             helpers_of_fto_dict , helpers_of_helpers_dict , project_root_path , remove_docstrings = True 
8796        )
8897        read_only_context_code  =  read_only_code_no_docstring_markdown .markdown 
8998        read_only_code_no_docstring_markdown_tokens  =  len (tokenizer .encode (read_only_context_code ))
@@ -116,13 +125,14 @@ def get_code_optimization_context(
116125            raise  ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
117126
118127    return  CodeOptimizationContext (
119-         testgen_context_code   =   testgen_context_code ,
128+         testgen_context_code = testgen_context_code ,
120129        read_writable_code = final_read_writable_code ,
121130        read_only_context_code = read_only_context_code ,
122131        helper_functions = helpers_of_fto_list ,
123132        preexisting_objects = preexisting_objects ,
124133    )
125134
135+ 
126136def  extract_code_string_context_from_files (
127137    helpers_of_fto : dict [Path , set [FunctionSource ]],
128138    helpers_of_helpers : dict [Path , set [FunctionSource ]],
@@ -169,9 +179,15 @@ def extract_code_string_context_from_files(
169179            continue 
170180        try :
171181            qualified_function_names  =  {func .qualified_name  for  func  in  function_sources }
172-             helpers_of_helpers_qualified_names  =  {func .qualified_name  for  func  in  helpers_of_helpers .get (file_path , set ())}
182+             helpers_of_helpers_qualified_names  =  {
183+                 func .qualified_name  for  func  in  helpers_of_helpers .get (file_path , set ())
184+             }
173185            code_context  =  parse_code_and_prune_cst (
174-                 original_code ,  code_context_type , qualified_function_names , helpers_of_helpers_qualified_names , remove_docstrings 
186+                 original_code ,
187+                 code_context_type ,
188+                 qualified_function_names ,
189+                 helpers_of_helpers_qualified_names ,
190+                 remove_docstrings ,
175191            )
176192
177193        except  ValueError  as  e :
@@ -180,12 +196,12 @@ def extract_code_string_context_from_files(
180196        if  code_context .strip ():
181197            final_code_string_context  +=  f"\n { code_context }  " 
182198            final_code_string_context  =  add_needed_imports_from_module (
183-                      src_module_code = original_code ,
184-                      dst_module_code = final_code_string_context ,
185-                      src_path = file_path ,
186-                      dst_path = file_path ,
187-                      project_root = project_root_path ,
188-                      helper_functions =   list (helpers_of_fto .get (file_path , set ()) |  helpers_of_helpers .get (file_path , set ()))
199+                 src_module_code = original_code ,
200+                 dst_module_code = final_code_string_context ,
201+                 src_path = file_path ,
202+                 dst_path = file_path ,
203+                 project_root = project_root_path ,
204+                 helper_functions = list (helpers_of_fto .get (file_path , set ()) |  helpers_of_helpers .get (file_path , set ())), 
189205            )
190206    if  code_context_type  ==  CodeContextType .READ_WRITABLE :
191207        return  CodeString (code = final_code_string_context )
@@ -199,7 +215,7 @@ def extract_code_string_context_from_files(
199215        try :
200216            qualified_helper_function_names  =  {func .qualified_name  for  func  in  helper_function_sources }
201217            code_context  =  parse_code_and_prune_cst (
202-                 original_code , code_context_type , set (), qualified_helper_function_names ,   remove_docstrings 
218+                 original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings 
203219            )
204220        except  ValueError  as  e :
205221            logger .debug (f"Error while getting read-only code: { e }  " )
@@ -208,15 +224,16 @@ def extract_code_string_context_from_files(
208224        if  code_context .strip ():
209225            final_code_string_context  +=  f"\n { code_context }  " 
210226            final_code_string_context  =  add_needed_imports_from_module (
211-                      src_module_code = original_code ,
212-                      dst_module_code = final_code_string_context ,
213-                      src_path = file_path ,
214-                      dst_path = file_path ,
215-                      project_root = project_root_path ,
216-                      helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
227+                 src_module_code = original_code ,
228+                 dst_module_code = final_code_string_context ,
229+                 src_path = file_path ,
230+                 dst_path = file_path ,
231+                 project_root = project_root_path ,
232+                 helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
217233            )
218234    return  CodeString (code = final_code_string_context )
219235
236+ 
220237def  extract_code_markdown_context_from_files (
221238    helpers_of_fto : dict [Path , set [FunctionSource ]],
222239    helpers_of_helpers : dict [Path , set [FunctionSource ]],
@@ -263,9 +280,15 @@ def extract_code_markdown_context_from_files(
263280            continue 
264281        try :
265282            qualified_function_names  =  {func .qualified_name  for  func  in  function_sources }
266-             helpers_of_helpers_qualified_names  =  {func .qualified_name  for  func  in  helpers_of_helpers .get (file_path , set ())}
283+             helpers_of_helpers_qualified_names  =  {
284+                 func .qualified_name  for  func  in  helpers_of_helpers .get (file_path , set ())
285+             }
267286            code_context  =  parse_code_and_prune_cst (
268-                 original_code ,  code_context_type , qualified_function_names , helpers_of_helpers_qualified_names , remove_docstrings 
287+                 original_code ,
288+                 code_context_type ,
289+                 qualified_function_names ,
290+                 helpers_of_helpers_qualified_names ,
291+                 remove_docstrings ,
269292            )
270293
271294        except  ValueError  as  e :
@@ -280,7 +303,8 @@ def extract_code_markdown_context_from_files(
280303                    dst_path = file_path ,
281304                    project_root = project_root_path ,
282305                    helper_functions = list (
283-                         helpers_of_fto .get (file_path , set ()) |  helpers_of_helpers .get (file_path , set ()))
306+                         helpers_of_fto .get (file_path , set ()) |  helpers_of_helpers .get (file_path , set ())
307+                     ),
284308                ),
285309                file_path = file_path .relative_to (project_root_path ),
286310            )
@@ -295,7 +319,7 @@ def extract_code_markdown_context_from_files(
295319        try :
296320            qualified_helper_function_names  =  {func .qualified_name  for  func  in  helper_function_sources }
297321            code_context  =  parse_code_and_prune_cst (
298-                 original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings , 
322+                 original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings 
299323            )
300324        except  ValueError  as  e :
301325            logger .debug (f"Error while getting read-only code: { e }  " )
@@ -317,8 +341,9 @@ def extract_code_markdown_context_from_files(
317341    return  code_context_markdown 
318342
319343
320- def  get_function_to_optimize_as_function_source (function_to_optimize : FunctionToOptimize ,
321-                                        project_root_path : Path ) ->  FunctionSource :
344+ def  get_function_to_optimize_as_function_source (
345+     function_to_optimize : FunctionToOptimize , project_root_path : Path 
346+ ) ->  FunctionSource :
322347    # Use jedi to find function to optimize 
323348    script  =  jedi .Script (path = function_to_optimize .file_path , project = jedi .Project (path = project_root_path ))
324349
@@ -327,11 +352,12 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo
327352
328353    # Find the name that matches our function 
329354    for  name  in  names :
330-         if  (name .type  ==  "function"  and 
331-         name .full_name  and 
332-                 name .name  ==  function_to_optimize .function_name  and 
333-                 get_qualified_name (name .module_name , name .full_name ) ==  function_to_optimize .qualified_name ):
334- 
355+         if  (
356+             name .type  ==  "function" 
357+             and  name .full_name 
358+             and  name .name  ==  function_to_optimize .function_name 
359+             and  get_qualified_name (name .module_name , name .full_name ) ==  function_to_optimize .qualified_name 
360+         ):
335361            function_source  =  FunctionSource (
336362                file_path = function_to_optimize .file_path ,
337363                qualified_name = function_to_optimize .qualified_name ,
@@ -343,7 +369,8 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo
343369            return  function_source 
344370
345371    raise  ValueError (
346-         f"Could not find function { function_to_optimize .function_name }   in { function_to_optimize .file_path }  " )
372+         f"Could not find function { function_to_optimize .function_name }   in { function_to_optimize .file_path }  " 
373+     )
347374
348375
349376def  get_function_sources_from_jedi (
@@ -417,8 +444,13 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
417444        return  indented_block .with_changes (body = indented_block .body [1 :])
418445    return  indented_block 
419446
447+ 
420448def  parse_code_and_prune_cst (
421-     code : str , code_context_type : CodeContextType , target_functions : set [str ], helpers_of_helper_functions : set [str ] =  set (), remove_docstrings : bool  =  False 
449+     code : str ,
450+     code_context_type : CodeContextType ,
451+     target_functions : set [str ],
452+     helpers_of_helper_functions : set [str ] =  set (),
453+     remove_docstrings : bool  =  False ,
422454) ->  str :
423455    """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables.""" 
424456    module  =  cst .parse_module (code )
@@ -441,6 +473,7 @@ def parse_code_and_prune_cst(
441473        return  str (filtered_node .code )
442474    return  "" 
443475
476+ 
444477def  prune_cst_for_read_writable_code (
445478    node : cst .CSTNode , target_functions : set [str ], prefix : str  =  "" 
446479) ->  tuple [cst .CSTNode  |  None , bool ]:
@@ -520,6 +553,7 @@ def prune_cst_for_read_writable_code(
520553
521554    return  (node .with_changes (** updates ) if  updates  else  node ), True 
522555
556+ 
523557def  prune_cst_for_read_only_code (
524558    node : cst .CSTNode ,
525559    target_functions : set [str ],
@@ -624,7 +658,6 @@ def prune_cst_for_read_only_code(
624658    return  None , False 
625659
626660
627- 
628661def  prune_cst_for_testgen_code (
629662    node : cst .CSTNode ,
630663    target_functions : set [str ],
0 commit comments