Skip to content

Commit 2ff091b

Browse files
committed
extract create_function_optimizer
1 parent 30e51b8 commit 2ff091b

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

codeflash/optimization/optimizer.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,35 @@ def create_function_optimizer(
131131
function_to_optimize_source_code: str | None = "",
132132
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
133133
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
134-
) -> FunctionOptimizer:
134+
original_module_ast: ast.Module | None = None,
135+
original_module_path: Path | None = None,
136+
) -> FunctionOptimizer | None:
137+
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
135138
from codeflash.optimization.function_optimizer import FunctionOptimizer
136139

140+
if function_to_optimize_ast is None and original_module_ast is not None:
141+
function_to_optimize_ast = get_first_top_level_function_or_method_ast(
142+
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
143+
)
144+
if function_to_optimize_ast is None:
145+
logger.info(
146+
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
147+
f"Skipping optimization."
148+
)
149+
return None
150+
151+
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
152+
153+
function_specific_timings = None
154+
if (
155+
hasattr(self.args, "benchmark")
156+
and self.args.benchmark
157+
and function_benchmark_timings
158+
and qualified_name_w_module in function_benchmark_timings
159+
and total_benchmark_timings
160+
):
161+
function_specific_timings = function_benchmark_timings[qualified_name_w_module]
162+
137163
return FunctionOptimizer(
138164
function_to_optimize=function_to_optimize,
139165
test_cfg=self.test_cfg,
@@ -142,18 +168,15 @@ def create_function_optimizer(
142168
function_to_optimize_ast=function_to_optimize_ast,
143169
aiservice_client=self.aiservice_client,
144170
args=self.args,
145-
function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
146-
total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
171+
function_benchmark_timings=function_specific_timings,
172+
total_benchmark_timings=total_benchmark_timings if function_specific_timings else None,
147173
replay_tests_dir=self.replay_tests_dir,
148174
)
149175

150176
def run(self) -> None:
151177
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
152178
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
153-
from codeflash.code_utils.static_analysis import (
154-
analyze_imported_modules,
155-
get_first_top_level_function_or_method_ast,
156-
)
179+
from codeflash.code_utils.static_analysis import analyze_imported_modules
157180
from codeflash.discovery.discover_unit_tests import discover_unit_tests
158181

159182
ph("cli-optimize-run-start")
@@ -245,40 +268,17 @@ def run(self) -> None:
245268
f"{function_to_optimize.qualified_name}"
246269
)
247270
console.rule()
248-
if not (
249-
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
250-
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
251-
)
252-
):
253-
logger.info(
254-
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
255-
f"Skipping optimization."
256-
)
257-
continue
258-
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
259-
self.args.project_root
271+
272+
function_optimizer = self.create_function_optimizer(
273+
function_to_optimize,
274+
function_to_tests=function_to_tests,
275+
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
276+
function_benchmark_timings=function_benchmark_timings,
277+
total_benchmark_timings=total_benchmark_timings,
278+
original_module_ast=original_module_ast,
279+
original_module_path=original_module_path,
260280
)
261-
if (
262-
self.args.benchmark
263-
and function_benchmark_timings
264-
and qualified_name_w_module in function_benchmark_timings
265-
and total_benchmark_timings
266-
):
267-
function_optimizer = self.create_function_optimizer(
268-
function_to_optimize,
269-
function_to_optimize_ast,
270-
function_to_tests,
271-
validated_original_code[original_module_path].source_code,
272-
function_benchmark_timings[qualified_name_w_module],
273-
total_benchmark_timings,
274-
)
275-
else:
276-
function_optimizer = self.create_function_optimizer(
277-
function_to_optimize,
278-
function_to_optimize_ast,
279-
function_to_tests,
280-
validated_original_code[original_module_path].source_code,
281-
)
281+
282282
self.current_function_optimizer = (
283283
function_optimizer # needed to clean up from the outside of this function
284284
)

0 commit comments

Comments
 (0)