Skip to content

Commit 24f21a5

Browse files
committed
refactor
1 parent ca11dc6 commit 24f21a5

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

codeflash/optimization/optimizer.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,23 @@ def create_function_optimizer(
190190
function_to_optimize_source_code: str | None = "",
191191
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
192192
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
193-
) -> FunctionOptimizer:
193+
original_module_ast: ast.Module | None = None,
194+
original_module_path: Path | None = None,
195+
) -> FunctionOptimizer | None:
196+
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
194197
from codeflash.optimization.function_optimizer import FunctionOptimizer
195198

199+
if function_to_optimize_ast is None and original_module_ast is not None:
200+
function_to_optimize_ast = get_first_top_level_function_or_method_ast(
201+
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
202+
)
203+
if function_to_optimize_ast is None:
204+
logger.info(
205+
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
206+
f"Skipping optimization."
207+
)
208+
return None
209+
196210
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
197211

198212
function_specific_timings = None
@@ -220,7 +234,6 @@ def create_function_optimizer(
220234
def run(self) -> None:
221235
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
222236
from codeflash.code_utils.code_utils import cleanup_paths
223-
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
224237

225238
ph("cli-optimize-run-start")
226239
logger.info("Running optimizer.")
@@ -266,26 +279,20 @@ def run(self) -> None:
266279
f"{function_to_optimize.qualified_name}"
267280
)
268281
console.rule()
269-
if not (
270-
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
271-
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
272-
)
273-
):
274-
logger.info(
275-
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
276-
f"Skipping optimization."
277-
)
278-
continue
279282

280283
function_optimizer = self.create_function_optimizer(
281284
function_to_optimize,
282-
function_to_optimize_ast,
283-
function_to_tests,
284-
validated_original_code[original_module_path].source_code,
285-
function_benchmark_timings,
286-
total_benchmark_timings,
285+
function_to_tests=function_to_tests,
286+
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
287+
function_benchmark_timings=function_benchmark_timings,
288+
total_benchmark_timings=total_benchmark_timings,
289+
original_module_ast=original_module_ast,
290+
original_module_path=original_module_path,
287291
)
288292

293+
if function_optimizer is None:
294+
continue
295+
289296
best_optimization = function_optimizer.optimize_function()
290297
if self.functions_checkpoint:
291298
self.functions_checkpoint.add_function_to_checkpoint(

0 commit comments

Comments
 (0)