2626
2727
2828def get_code_optimization_context (
29- function_to_optimize : FunctionToOptimize , project_root_path : Path , optim_token_limit : int = 8000 , testgen_token_limit : int = 8000
29+ function_to_optimize : FunctionToOptimize ,
30+ project_root_path : Path ,
31+ optim_token_limit : int = 8000 ,
32+ testgen_token_limit : int = 8000 ,
3033) -> CodeOptimizationContext :
3134 # Get FunctionSource representation of helpers of FTO
32- helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi ({function_to_optimize .file_path : {function_to_optimize .qualified_name }}, project_root_path )
35+ helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (
36+ {function_to_optimize .file_path : {function_to_optimize .qualified_name }}, project_root_path
37+ )
3338
3439 # Add function to optimize into helpers of FTO dict, as they'll be processed together
3540 fto_as_function_source = get_function_to_optimize_as_function_source (function_to_optimize , project_root_path )
3641 helpers_of_fto_dict [function_to_optimize .file_path ].add (fto_as_function_source )
3742
3843 # Format data to search for helpers of helpers using get_function_sources_from_jedi
3944 helpers_of_fto_qualified_names_dict = {
40- file_path : {source .qualified_name for source in sources }
41- for file_path , sources in helpers_of_fto_dict .items ()
45+ file_path : {source .qualified_name for source in sources } for file_path , sources in helpers_of_fto_dict .items ()
4246 }
4347
4448 # __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist)
4549 # This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO
4650 for qualified_names in helpers_of_fto_qualified_names_dict .values ():
47- qualified_names .update ({f"{ qn .rsplit ('.' , 1 )[0 ]} .__init__" for qn in qualified_names if '.' in qn })
51+ qualified_names .update ({f"{ qn .rsplit ('.' , 1 )[0 ]} .__init__" for qn in qualified_names if "." in qn })
4852
4953 # Get FunctionSource representation of helpers of helpers of FTO
50- helpers_of_helpers_dict , helpers_of_helpers_list = get_function_sources_from_jedi (helpers_of_fto_qualified_names_dict , project_root_path )
54+ helpers_of_helpers_dict , helpers_of_helpers_list = get_function_sources_from_jedi (
55+ helpers_of_fto_qualified_names_dict , project_root_path
56+ )
5157
5258 # Extract code context for optimization
53- final_read_writable_code = extract_code_string_context_from_files (helpers_of_fto_dict ,{}, project_root_path , remove_docstrings = False , code_context_type = CodeContextType .READ_WRITABLE ).code
59+ final_read_writable_code = extract_code_string_context_from_files (
60+ helpers_of_fto_dict ,
61+ {},
62+ project_root_path ,
63+ remove_docstrings = False ,
64+ code_context_type = CodeContextType .READ_WRITABLE ,
65+ ).code
5466 read_only_code_markdown = extract_code_markdown_context_from_files (
5567 helpers_of_fto_dict ,
5668 helpers_of_helpers_dict ,
@@ -80,10 +92,7 @@ def get_code_optimization_context(
8092 logger .debug ("Code context has exceeded token limit, removing docstrings from read-only code" )
8193 # Extract read only code without docstrings
8294 read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files (
83- helpers_of_fto_dict ,
84- helpers_of_helpers_dict ,
85- project_root_path ,
86- remove_docstrings = True ,
95+ helpers_of_fto_dict , helpers_of_helpers_dict , project_root_path , remove_docstrings = True
8796 )
8897 read_only_context_code = read_only_code_no_docstring_markdown .markdown
8998 read_only_code_no_docstring_markdown_tokens = len (tokenizer .encode (read_only_context_code ))
@@ -116,13 +125,14 @@ def get_code_optimization_context(
116125 raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
117126
118127 return CodeOptimizationContext (
119- testgen_context_code = testgen_context_code ,
128+ testgen_context_code = testgen_context_code ,
120129 read_writable_code = final_read_writable_code ,
121130 read_only_context_code = read_only_context_code ,
122131 helper_functions = helpers_of_fto_list ,
123132 preexisting_objects = preexisting_objects ,
124133 )
125134
135+
126136def extract_code_string_context_from_files (
127137 helpers_of_fto : dict [Path , set [FunctionSource ]],
128138 helpers_of_helpers : dict [Path , set [FunctionSource ]],
@@ -169,9 +179,15 @@ def extract_code_string_context_from_files(
169179 continue
170180 try :
171181 qualified_function_names = {func .qualified_name for func in function_sources }
172- helpers_of_helpers_qualified_names = {func .qualified_name for func in helpers_of_helpers .get (file_path , set ())}
182+ helpers_of_helpers_qualified_names = {
183+ func .qualified_name for func in helpers_of_helpers .get (file_path , set ())
184+ }
173185 code_context = parse_code_and_prune_cst (
174- original_code , code_context_type , qualified_function_names , helpers_of_helpers_qualified_names , remove_docstrings
186+ original_code ,
187+ code_context_type ,
188+ qualified_function_names ,
189+ helpers_of_helpers_qualified_names ,
190+ remove_docstrings ,
175191 )
176192
177193 except ValueError as e :
@@ -180,12 +196,12 @@ def extract_code_string_context_from_files(
180196 if code_context .strip ():
181197 final_code_string_context += f"\n { code_context } "
182198 final_code_string_context = add_needed_imports_from_module (
183- src_module_code = original_code ,
184- dst_module_code = final_code_string_context ,
185- src_path = file_path ,
186- dst_path = file_path ,
187- project_root = project_root_path ,
188- helper_functions = list (helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ()))
199+ src_module_code = original_code ,
200+ dst_module_code = final_code_string_context ,
201+ src_path = file_path ,
202+ dst_path = file_path ,
203+ project_root = project_root_path ,
204+ helper_functions = list (helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())),
189205 )
190206 if code_context_type == CodeContextType .READ_WRITABLE :
191207 return CodeString (code = final_code_string_context )
@@ -199,7 +215,7 @@ def extract_code_string_context_from_files(
199215 try :
200216 qualified_helper_function_names = {func .qualified_name for func in helper_function_sources }
201217 code_context = parse_code_and_prune_cst (
202- original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings
218+ original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings
203219 )
204220 except ValueError as e :
205221 logger .debug (f"Error while getting read-only code: { e } " )
@@ -208,15 +224,16 @@ def extract_code_string_context_from_files(
208224 if code_context .strip ():
209225 final_code_string_context += f"\n { code_context } "
210226 final_code_string_context = add_needed_imports_from_module (
211- src_module_code = original_code ,
212- dst_module_code = final_code_string_context ,
213- src_path = file_path ,
214- dst_path = file_path ,
215- project_root = project_root_path ,
216- helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
227+ src_module_code = original_code ,
228+ dst_module_code = final_code_string_context ,
229+ src_path = file_path ,
230+ dst_path = file_path ,
231+ project_root = project_root_path ,
232+ helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
217233 )
218234 return CodeString (code = final_code_string_context )
219235
236+
220237def extract_code_markdown_context_from_files (
221238 helpers_of_fto : dict [Path , set [FunctionSource ]],
222239 helpers_of_helpers : dict [Path , set [FunctionSource ]],
@@ -263,9 +280,15 @@ def extract_code_markdown_context_from_files(
263280 continue
264281 try :
265282 qualified_function_names = {func .qualified_name for func in function_sources }
266- helpers_of_helpers_qualified_names = {func .qualified_name for func in helpers_of_helpers .get (file_path , set ())}
283+ helpers_of_helpers_qualified_names = {
284+ func .qualified_name for func in helpers_of_helpers .get (file_path , set ())
285+ }
267286 code_context = parse_code_and_prune_cst (
268- original_code , code_context_type , qualified_function_names , helpers_of_helpers_qualified_names , remove_docstrings
287+ original_code ,
288+ code_context_type ,
289+ qualified_function_names ,
290+ helpers_of_helpers_qualified_names ,
291+ remove_docstrings ,
269292 )
270293
271294 except ValueError as e :
@@ -280,7 +303,8 @@ def extract_code_markdown_context_from_files(
280303 dst_path = file_path ,
281304 project_root = project_root_path ,
282305 helper_functions = list (
283- helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ()))
306+ helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
307+ ),
284308 ),
285309 file_path = file_path .relative_to (project_root_path ),
286310 )
@@ -295,7 +319,7 @@ def extract_code_markdown_context_from_files(
295319 try :
296320 qualified_helper_function_names = {func .qualified_name for func in helper_function_sources }
297321 code_context = parse_code_and_prune_cst (
298- original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings ,
322+ original_code , code_context_type , set (), qualified_helper_function_names , remove_docstrings
299323 )
300324 except ValueError as e :
301325 logger .debug (f"Error while getting read-only code: { e } " )
@@ -317,8 +341,9 @@ def extract_code_markdown_context_from_files(
317341 return code_context_markdown
318342
319343
320- def get_function_to_optimize_as_function_source (function_to_optimize : FunctionToOptimize ,
321- project_root_path : Path ) -> FunctionSource :
344+ def get_function_to_optimize_as_function_source (
345+ function_to_optimize : FunctionToOptimize , project_root_path : Path
346+ ) -> FunctionSource :
322347 # Use jedi to find function to optimize
323348 script = jedi .Script (path = function_to_optimize .file_path , project = jedi .Project (path = project_root_path ))
324349
@@ -327,11 +352,12 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo
327352
328353 # Find the name that matches our function
329354 for name in names :
330- if (name .type == "function" and
331- name .full_name and
332- name .name == function_to_optimize .function_name and
333- get_qualified_name (name .module_name , name .full_name ) == function_to_optimize .qualified_name ):
334-
355+ if (
356+ name .type == "function"
357+ and name .full_name
358+ and name .name == function_to_optimize .function_name
359+ and get_qualified_name (name .module_name , name .full_name ) == function_to_optimize .qualified_name
360+ ):
335361 function_source = FunctionSource (
336362 file_path = function_to_optimize .file_path ,
337363 qualified_name = function_to_optimize .qualified_name ,
@@ -343,7 +369,8 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo
343369 return function_source
344370
345371 raise ValueError (
346- f"Could not find function { function_to_optimize .function_name } in { function_to_optimize .file_path } " )
372+ f"Could not find function { function_to_optimize .function_name } in { function_to_optimize .file_path } "
373+ )
347374
348375
349376def get_function_sources_from_jedi (
@@ -417,8 +444,13 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode
417444 return indented_block .with_changes (body = indented_block .body [1 :])
418445 return indented_block
419446
447+
420448def parse_code_and_prune_cst (
421- code : str , code_context_type : CodeContextType , target_functions : set [str ], helpers_of_helper_functions : set [str ] = set (), remove_docstrings : bool = False
449+ code : str ,
450+ code_context_type : CodeContextType ,
451+ target_functions : set [str ],
452+ helpers_of_helper_functions : set [str ] = set (),
453+ remove_docstrings : bool = False ,
422454) -> str :
423455 """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
424456 module = cst .parse_module (code )
@@ -441,6 +473,7 @@ def parse_code_and_prune_cst(
441473 return str (filtered_node .code )
442474 return ""
443475
476+
444477def prune_cst_for_read_writable_code (
445478 node : cst .CSTNode , target_functions : set [str ], prefix : str = ""
446479) -> tuple [cst .CSTNode | None , bool ]:
@@ -520,6 +553,7 @@ def prune_cst_for_read_writable_code(
520553
521554 return (node .with_changes (** updates ) if updates else node ), True
522555
556+
523557def prune_cst_for_read_only_code (
524558 node : cst .CSTNode ,
525559 target_functions : set [str ],
@@ -624,7 +658,6 @@ def prune_cst_for_read_only_code(
624658 return None , False
625659
626660
627-
628661def prune_cst_for_testgen_code (
629662 node : cst .CSTNode ,
630663 target_functions : set [str ],
0 commit comments