@@ -893,15 +893,14 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
893893 else :
894894 msg = f"Unexpected test type: { test_type } "
895895 raise ValueError (msg )
896-
896+
897897 success , injected_behavior_test = inject_profiling_into_existing_test (
898898 mode = TestingMode .BEHAVIOR ,
899899 test_path = path_obj_test_file ,
900900 call_positions = [test .position for test in tests_in_file_list ],
901901 function_to_optimize = self .function_to_optimize ,
902902 tests_project_root = self .test_cfg .tests_project_rootdir ,
903903 test_framework = self .args .test_framework ,
904- source_module_path = self .function_to_optimize .file_path ,
905904 )
906905 if not success :
907906 continue
@@ -912,7 +911,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
912911 function_to_optimize = self .function_to_optimize ,
913912 tests_project_root = self .test_cfg .tests_project_rootdir ,
914913 test_framework = self .args .test_framework ,
915- source_module_path = self .function_to_optimize .file_path ,
916914 )
917915 if not success :
918916 continue
@@ -1368,15 +1366,18 @@ def establish_original_code_baseline(
13681366 test_env = self .get_test_env (codeflash_loop_index = 0 , codeflash_test_iteration = 0 , codeflash_tracer_disable = 1 )
13691367
13701368 if self .function_to_optimize .is_async :
1371- from codeflash .code_utils .instrument_existing_tests import instrument_source_module_with_async_decorators
1369+ from codeflash .code_utils .instrument_existing_tests import (
1370+ instrument_source_module_with_async_decorators ,
1371+ )
1372+
13721373 success , instrumented_source = instrument_source_module_with_async_decorators (
13731374 self .function_to_optimize .file_path , self .function_to_optimize , TestingMode .BEHAVIOR
13741375 )
13751376 if success and instrumented_source :
13761377 with self .function_to_optimize .file_path .open ("w" , encoding = "utf8" ) as f :
13771378 f .write (instrumented_source )
13781379 logger .debug (f"Applied async instrumentation to { self .function_to_optimize .file_path } " )
1379-
1380+
13801381 # Instrument codeflash capture
13811382 try :
13821383 instrument_codeflash_capture (
@@ -1392,9 +1393,10 @@ def establish_original_code_baseline(
13921393 code_context = code_context ,
13931394 )
13941395 finally :
1395- self .write_code_and_helpers (
1396- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1397- )
1396+ if not self .function_to_optimize .is_async :
1397+ self .write_code_and_helpers (
1398+ self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1399+ )
13981400 if not behavioral_results :
13991401 logger .warning (
14001402 f"Couldn't run any tests for original function { self .function_to_optimize .function_name } . SKIPPING OPTIMIZING THIS FUNCTION."
@@ -1408,18 +1410,22 @@ def establish_original_code_baseline(
14081410 code_context = code_context , original_helper_code = original_helper_code , candidate_index = 0
14091411 )
14101412 console .rule ()
1411-
1412- # For async functions, instrument at definition site for performance benchmarking
1413+
14131414 if self .function_to_optimize .is_async :
1414- from codeflash .code_utils .instrument_existing_tests import instrument_source_module_with_async_decorators
1415+ from codeflash .code_utils .instrument_existing_tests import (
1416+ instrument_source_module_with_async_decorators ,
1417+ )
1418+
14151419 success , instrumented_source = instrument_source_module_with_async_decorators (
14161420 self .function_to_optimize .file_path , self .function_to_optimize , TestingMode .PERFORMANCE
14171421 )
14181422 if success and instrumented_source :
14191423 with self .function_to_optimize .file_path .open ("w" , encoding = "utf8" ) as f :
14201424 f .write (instrumented_source )
1421- logger .debug (f"Applied async performance instrumentation to { self .function_to_optimize .file_path } " )
1422-
1425+ logger .debug (
1426+ f"Applied async performance instrumentation to { self .function_to_optimize .file_path } "
1427+ )
1428+
14231429 try :
14241430 benchmarking_results , _ = self .run_and_parse_tests (
14251431 testing_type = TestingMode .PERFORMANCE ,
@@ -1431,10 +1437,11 @@ def establish_original_code_baseline(
14311437 code_context = code_context ,
14321438 )
14331439 finally :
1434- # Restore original source if we instrumented it
14351440 if self .function_to_optimize .is_async :
14361441 self .write_code_and_helpers (
1437- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1442+ self .function_to_optimize_source_code ,
1443+ original_helper_code ,
1444+ self .function_to_optimize .file_path ,
14381445 )
14391446 else :
14401447 benchmarking_results = TestResults ()
@@ -1570,15 +1577,20 @@ def run_optimized_candidate(
15701577 if test_framework == "pytest" :
15711578 # For async functions, instrument at definition site for performance benchmarking
15721579 if self .function_to_optimize .is_async :
1573- from codeflash .code_utils .instrument_existing_tests import instrument_source_module_with_async_decorators
1580+ from codeflash .code_utils .instrument_existing_tests import (
1581+ instrument_source_module_with_async_decorators ,
1582+ )
1583+
15741584 success , instrumented_source = instrument_source_module_with_async_decorators (
15751585 self .function_to_optimize .file_path , self .function_to_optimize , TestingMode .PERFORMANCE
15761586 )
15771587 if success and instrumented_source :
15781588 with self .function_to_optimize .file_path .open ("w" , encoding = "utf8" ) as f :
15791589 f .write (instrumented_source )
1580- logger .debug (f"Applied async performance instrumentation to { self .function_to_optimize .file_path } for candidate { optimization_candidate_index } " )
1581-
1590+ logger .debug (
1591+ f"Applied async performance instrumentation to { self .function_to_optimize .file_path } for candidate { optimization_candidate_index } "
1592+ )
1593+
15821594 try :
15831595 candidate_benchmarking_results , _ = self .run_and_parse_tests (
15841596 testing_type = TestingMode .PERFORMANCE ,
0 commit comments