@@ -131,9 +131,35 @@ def create_function_optimizer(
131131 function_to_optimize_source_code : str | None = "" ,
132132 function_benchmark_timings : dict [str , dict [BenchmarkKey , float ]] | None = None ,
133133 total_benchmark_timings : dict [BenchmarkKey , float ] | None = None ,
134- ) -> FunctionOptimizer :
134+ original_module_ast : ast .Module | None = None ,
135+ original_module_path : Path | None = None ,
136+ ) -> FunctionOptimizer | None :
137+ from codeflash .code_utils .static_analysis import get_first_top_level_function_or_method_ast
135138 from codeflash .optimization .function_optimizer import FunctionOptimizer
136139
140+ if function_to_optimize_ast is None and original_module_ast is not None :
141+ function_to_optimize_ast = get_first_top_level_function_or_method_ast (
142+ function_to_optimize .function_name , function_to_optimize .parents , original_module_ast
143+ )
144+ if function_to_optimize_ast is None :
145+ logger .info (
146+ f"Function { function_to_optimize .qualified_name } not found in { original_module_path } .\n "
147+ f"Skipping optimization."
148+ )
149+ return None
150+
151+ qualified_name_w_module = function_to_optimize .qualified_name_with_modules_from_root (self .args .project_root )
152+
153+ function_specific_timings = None
154+ if (
155+ hasattr (self .args , "benchmark" )
156+ and self .args .benchmark
157+ and function_benchmark_timings
158+ and qualified_name_w_module in function_benchmark_timings
159+ and total_benchmark_timings
160+ ):
161+ function_specific_timings = function_benchmark_timings [qualified_name_w_module ]
162+
137163 return FunctionOptimizer (
138164 function_to_optimize = function_to_optimize ,
139165 test_cfg = self .test_cfg ,
@@ -142,18 +168,15 @@ def create_function_optimizer(
142168 function_to_optimize_ast = function_to_optimize_ast ,
143169 aiservice_client = self .aiservice_client ,
144170 args = self .args ,
145- function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else None ,
146- total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else None ,
171+ function_benchmark_timings = function_specific_timings ,
172+ total_benchmark_timings = total_benchmark_timings if function_specific_timings else None ,
147173 replay_tests_dir = self .replay_tests_dir ,
148174 )
149175
150176 def run (self ) -> None :
151177 from codeflash .code_utils .checkpoint import CodeflashRunCheckpoint
152178 from codeflash .code_utils .code_replacer import normalize_code , normalize_node
153- from codeflash .code_utils .static_analysis import (
154- analyze_imported_modules ,
155- get_first_top_level_function_or_method_ast ,
156- )
179+ from codeflash .code_utils .static_analysis import analyze_imported_modules
157180 from codeflash .discovery .discover_unit_tests import discover_unit_tests
158181
159182 ph ("cli-optimize-run-start" )
@@ -245,40 +268,17 @@ def run(self) -> None:
245268 f"{ function_to_optimize .qualified_name } "
246269 )
247270 console .rule ()
248- if not (
249- function_to_optimize_ast := get_first_top_level_function_or_method_ast (
250- function_to_optimize .function_name , function_to_optimize .parents , original_module_ast
251- )
252- ):
253- logger .info (
254- f"Function { function_to_optimize .qualified_name } not found in { original_module_path } .\n "
255- f"Skipping optimization."
256- )
257- continue
258- qualified_name_w_module = function_to_optimize .qualified_name_with_modules_from_root (
259- self .args .project_root
271+
272+ function_optimizer = self .create_function_optimizer (
273+ function_to_optimize ,
274+ function_to_tests = function_to_tests ,
275+ function_to_optimize_source_code = validated_original_code [original_module_path ].source_code ,
276+ function_benchmark_timings = function_benchmark_timings ,
277+ total_benchmark_timings = total_benchmark_timings ,
278+ original_module_ast = original_module_ast ,
279+ original_module_path = original_module_path ,
260280 )
261- if (
262- self .args .benchmark
263- and function_benchmark_timings
264- and qualified_name_w_module in function_benchmark_timings
265- and total_benchmark_timings
266- ):
267- function_optimizer = self .create_function_optimizer (
268- function_to_optimize ,
269- function_to_optimize_ast ,
270- function_to_tests ,
271- validated_original_code [original_module_path ].source_code ,
272- function_benchmark_timings [qualified_name_w_module ],
273- total_benchmark_timings ,
274- )
275- else :
276- function_optimizer = self .create_function_optimizer (
277- function_to_optimize ,
278- function_to_optimize_ast ,
279- function_to_tests ,
280- validated_original_code [original_module_path ].source_code ,
281- )
281+
282282 self .current_function_optimizer = (
283283 function_optimizer # needed to clean up from the outside of this function
284284 )
0 commit comments