Skip to content

Commit c9aaaad

Browse files
committed
don't reapply too early
1 parent 0a57afa commit c9aaaad

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ def find_and_update_line_node(
7777
call_node = node
7878
if isinstance(node.func, ast.Name):
7979
function_name = node.func.id
80-
80+
8181
if self.function_object.is_async:
8282
return [test_node]
83-
83+
8484
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
8585
node.args = [
8686
ast.Name(id=function_name, ctx=ast.Load()),
@@ -104,7 +104,7 @@ def find_and_update_line_node(
104104
if function_to_test == self.function_object.function_name:
105105
if self.function_object.is_async:
106106
return [test_node]
107-
107+
108108
function_name = ast.unparse(node.func)
109109
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
110110
node.args = [
@@ -359,11 +359,10 @@ def inject_profiling_into_existing_test(
359359
tests_project_root: Path,
360360
test_framework: str,
361361
mode: TestingMode = TestingMode.BEHAVIOR,
362-
source_module_path: Path | None = None,
363362
) -> tuple[bool, str | None]:
364363
if function_to_optimize.is_async:
365364
return False, None
366-
365+
367366
with test_path.open(encoding="utf8") as f:
368367
test_code = f.read()
369368
try:
@@ -390,7 +389,7 @@ def inject_profiling_into_existing_test(
390389
if test_framework == "unittest":
391390
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
392391
additional_functions = [create_wrapper_function(mode)]
393-
392+
394393
tree.body = [*new_imports, *additional_functions, *tree.body]
395394
return True, isort.code(ast.unparse(tree), float_to_top=True)
396395

@@ -925,5 +924,3 @@ def add_async_decorator_to_function(
925924
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
926925
instrumented_filename = f"instrumented_{source_path.name}"
927926
return temp_dir / instrumented_filename
928-
929-

codeflash/optimization/function_optimizer.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -893,15 +893,14 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
893893
else:
894894
msg = f"Unexpected test type: {test_type}"
895895
raise ValueError(msg)
896-
896+
897897
success, injected_behavior_test = inject_profiling_into_existing_test(
898898
mode=TestingMode.BEHAVIOR,
899899
test_path=path_obj_test_file,
900900
call_positions=[test.position for test in tests_in_file_list],
901901
function_to_optimize=self.function_to_optimize,
902902
tests_project_root=self.test_cfg.tests_project_rootdir,
903903
test_framework=self.args.test_framework,
904-
source_module_path=self.function_to_optimize.file_path,
905904
)
906905
if not success:
907906
continue
@@ -912,7 +911,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
912911
function_to_optimize=self.function_to_optimize,
913912
tests_project_root=self.test_cfg.tests_project_rootdir,
914913
test_framework=self.args.test_framework,
915-
source_module_path=self.function_to_optimize.file_path,
916914
)
917915
if not success:
918916
continue
@@ -1368,15 +1366,18 @@ def establish_original_code_baseline(
13681366
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
13691367

13701368
if self.function_to_optimize.is_async:
1371-
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
1369+
from codeflash.code_utils.instrument_existing_tests import (
1370+
instrument_source_module_with_async_decorators,
1371+
)
1372+
13721373
success, instrumented_source = instrument_source_module_with_async_decorators(
13731374
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
13741375
)
13751376
if success and instrumented_source:
13761377
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
13771378
f.write(instrumented_source)
13781379
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
1379-
1380+
13801381
# Instrument codeflash capture
13811382
try:
13821383
instrument_codeflash_capture(
@@ -1392,9 +1393,10 @@ def establish_original_code_baseline(
13921393
code_context=code_context,
13931394
)
13941395
finally:
1395-
self.write_code_and_helpers(
1396-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1397-
)
1396+
if not self.function_to_optimize.is_async:
1397+
self.write_code_and_helpers(
1398+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1399+
)
13981400
if not behavioral_results:
13991401
logger.warning(
14001402
f"Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
@@ -1408,18 +1410,22 @@ def establish_original_code_baseline(
14081410
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
14091411
)
14101412
console.rule()
1411-
1412-
# For async functions, instrument at definition site for performance benchmarking
1413+
14131414
if self.function_to_optimize.is_async:
1414-
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
1415+
from codeflash.code_utils.instrument_existing_tests import (
1416+
instrument_source_module_with_async_decorators,
1417+
)
1418+
14151419
success, instrumented_source = instrument_source_module_with_async_decorators(
14161420
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
14171421
)
14181422
if success and instrumented_source:
14191423
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
14201424
f.write(instrumented_source)
1421-
logger.debug(f"Applied async performance instrumentation to {self.function_to_optimize.file_path}")
1422-
1425+
logger.debug(
1426+
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
1427+
)
1428+
14231429
try:
14241430
benchmarking_results, _ = self.run_and_parse_tests(
14251431
testing_type=TestingMode.PERFORMANCE,
@@ -1431,10 +1437,11 @@ def establish_original_code_baseline(
14311437
code_context=code_context,
14321438
)
14331439
finally:
1434-
# Restore original source if we instrumented it
14351440
if self.function_to_optimize.is_async:
14361441
self.write_code_and_helpers(
1437-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1442+
self.function_to_optimize_source_code,
1443+
original_helper_code,
1444+
self.function_to_optimize.file_path,
14381445
)
14391446
else:
14401447
benchmarking_results = TestResults()
@@ -1570,15 +1577,20 @@ def run_optimized_candidate(
15701577
if test_framework == "pytest":
15711578
# For async functions, instrument at definition site for performance benchmarking
15721579
if self.function_to_optimize.is_async:
1573-
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
1580+
from codeflash.code_utils.instrument_existing_tests import (
1581+
instrument_source_module_with_async_decorators,
1582+
)
1583+
15741584
success, instrumented_source = instrument_source_module_with_async_decorators(
15751585
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
15761586
)
15771587
if success and instrumented_source:
15781588
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
15791589
f.write(instrumented_source)
1580-
logger.debug(f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}")
1581-
1590+
logger.debug(
1591+
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
1592+
)
1593+
15821594
try:
15831595
candidate_benchmarking_results, _ = self.run_and_parse_tests(
15841596
testing_type=TestingMode.PERFORMANCE,

0 commit comments

Comments
 (0)