@@ -2,7 +2,6 @@
 
 import ast
 import os
-import shutil
 import tempfile
 import time
 from collections import defaultdict
@@ -18,7+17,7 @@
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
-from codeflash.code_utils.code_utils import get_run_tmp_file
+from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
 from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
@@ -52,6 +51,11 @@ def __init__(self, args: Namespace) -> None:
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
         self.replay_tests_dir = None
+        if self.args.test_framework == "pytest":
+            self.test_cfg.concolic_test_root_dir = Path(
+                tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
+            )
+
     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
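
A note on the tempfile semantics the added __init__ block (and the replay-tests setup further down) relies on: tempfile.mkdtemp creates the directory immediately and returns its path, but never deletes it, so every mkdtemp call here needs a matching removal in the cleanup() method added at the end of this diff. A minimal sketch, runnable on its own:

    import tempfile
    from pathlib import Path

    # mkdtemp eagerly creates a directory readable and writable only by the
    # creating user, and returns its absolute path as a string.
    concolic_dir = Path(tempfile.mkdtemp(prefix="codeflash_concolic_"))
    assert concolic_dir.exists()
    # Unlike tempfile.TemporaryDirectory, nothing removes it automatically;
    # the caller owns deletion (here, cleanup() below).
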
@@ -71,7 +75,7 @@ def create_function_optimizer(
             args=self.args,
             function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
             total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
-            replay_tests_dir = self.replay_tests_dir
+            replay_tests_dir=self.replay_tests_dir,
         )
 
     def run(self) -> None:
@@ -81,6 +85,7 @@ def run(self) -> None:
         if not env_utils.ensure_codeflash_api_key():
             return
         function_optimizer = None
+        trace_file = None
         file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
         num_optimizable_functions: int
 
@@ -98,10 +103,7 @@ def run(self) -> None:
         function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
         total_benchmark_timings: dict[BenchmarkKey, int] = {}
         if self.args.benchmark and num_optimizable_functions > 0:
-            with progress_bar(
-                    f"Running benchmarks in {self.args.benchmarks_root}",
-                    transient=True,
-            ):
+            with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
                 # Insert decorator
                 file_path_to_source_code = defaultdict(str)
                 for file in file_to_funcs_to_optimize:
@@ -113,30 +115,35 @@ def run(self) -> None:
                     if trace_file.exists():
                         trace_file.unlink()
 
-                    self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
-                    trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark
+                    self.replay_tests_dir = Path(
+                        tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
+                    )
+                    trace_benchmarks_pytest(
+                        self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
+                    )  # Run all tests that use pytest-benchmark
                     replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
                     if replay_count == 0:
-                        logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
+                        logger.info(
+                            f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
+                        )
                     else:
                         function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
                         total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
-                        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                        function_to_results = validate_and_format_benchmark_table(
+                            function_benchmark_timings, total_benchmark_timings
+                        )
                         print_benchmark_table(function_to_results)
                 except Exception as e:
                     logger.info(f"Error while tracing existing benchmarks: {e}")
                     logger.info("Information on existing benchmarks will not be available for this run.")
+                    self.cleanup(function_optimizer=None)
                 finally:
                     # Restore original source code
                     for file in file_path_to_source_code:
                         with file.open("w", encoding="utf8") as f:
                             f.write(file_path_to_source_code[file])
         optimizations_found: int = 0
         function_iterator_count: int = 0
-        if self.args.test_framework == "pytest":
-            self.test_cfg.concolic_test_root_dir = Path(
-                tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
-            )
         try:
             ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
             if num_optimizable_functions == 0:
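
The tracing block above uses a snapshot-and-restore pattern: each file's original source is cached in file_path_to_source_code before the benchmark decorators are injected, and the finally clause writes it back even if tracing raises. A minimal sketch of the same idea (the file name is illustrative):

    from pathlib import Path

    snapshots: dict[Path, str] = {}
    target = Path("example_module.py")  # hypothetical file that gets instrumented
    snapshots[target] = target.read_text(encoding="utf8")  # snapshot before mutating
    try:
        target.write_text("# instrumented source\n", encoding="utf8")
        ...  # run the benchmark tracer against the instrumented file
    finally:
        for path, source in snapshots.items():
            path.write_text(source, encoding="utf8")  # always restore the original
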
@@ -148,11 +155,12 @@ def run(self) -> None:
             function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
             num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
             console.rule()
-            logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}")
+            logger.info(
+                f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
+            )
             console.rule()
             ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
 
-
             for original_module_path in file_to_funcs_to_optimize:
                 logger.info(f"Examining file {original_module_path!s}…")
                 console.rule()
@@ -212,14 +220,26 @@ def run(self) -> None:
                     qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
                         self.args.project_root
                     )
-                    if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
+                    if (
+                        self.args.benchmark
+                        and function_benchmark_timings
+                        and qualified_name_w_module in function_benchmark_timings
+                        and total_benchmark_timings
+                    ):
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
+                            function_benchmark_timings[qualified_name_w_module],
+                            total_benchmark_timings,
                         )
                     else:
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests,
-                            validated_original_code[original_module_path].source_code
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
                         )
 
                     best_optimization = function_optimizer.optimize_function()
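
Both branches above call create_function_optimizer with the same core arguments; the first additionally threads the per-function and total benchmark timings through, and it is taken only when the function's dotted path actually appears in the traced data. A small sketch of that guard (all values are hypothetical):

    function_benchmark_timings = {"pkg.mod.func": {"test_bench_a": 1_200_000}}  # made-up per-benchmark timings
    total_benchmark_timings = {"test_bench_a": 5_000_000}
    qualified_name = "pkg.mod.func"
    if function_benchmark_timings and qualified_name in function_benchmark_timings and total_benchmark_timings:
        timings_for_function = function_benchmark_timings[qualified_name]  # only then handed to the optimizer
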
@@ -235,23 +255,42 @@ def run(self) -> None:
             elif self.args.all:
                 logger.info("✨ All functions have been optimized! ✨")
         finally:
-            if function_optimizer:
-                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                    test_file.benchmarking_file_path.unlink(missing_ok=True)
-                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                    test_file.benchmarking_file_path.unlink(missing_ok=True)
-                for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                if function_optimizer.test_cfg.concolic_test_root_dir:
-                    shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True)
-                if self.args.benchmark:
-                    if self.replay_tests_dir.exists():
-                        shutil.rmtree(self.replay_tests_dir, ignore_errors=True)
-                    trace_file.unlink(missing_ok=True)
-            if hasattr(get_run_tmp_file, "tmpdir"):
-                get_run_tmp_file.tmpdir.cleanup()
+            self.cleanup(function_optimizer=function_optimizer)
+
+    def cleanup(self, function_optimizer: FunctionOptimizer | None) -> None:
+        paths_to_cleanup: list[Path] = []
+        if function_optimizer:
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.benchmarking_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.benchmarking_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files
+            )
+
+        if self.args.benchmark and self.replay_tests_dir and self.replay_tests_dir.exists():
+            paths_to_cleanup.append(self.replay_tests_dir)
+
+        if self.test_cfg.concolic_test_root_dir and self.test_cfg.concolic_test_root_dir.exists():
+            paths_to_cleanup.append(self.test_cfg.concolic_test_root_dir)
+
+        cleanup_paths(paths_to_cleanup)
+
+        if hasattr(get_run_tmp_file, "tmpdir"):
+            get_run_tmp_file.tmpdir.cleanup()
 
 
 def run_with_args(args: Namespace) -> None:
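
The new cleanup() funnels every temporary artifact through a single cleanup_paths call imported from codeflash.code_utils.code_utils. Its implementation is not shown in this diff; a plausible minimal sketch, assuming it accepts files and directories alike and tolerates paths that no longer exist:

    import shutil
    from pathlib import Path

    def cleanup_paths(paths: list[Path]) -> None:  # assumed signature, illustration only
        for path in paths:
            if path.is_dir():
                shutil.rmtree(path, ignore_errors=True)  # replay/concolic test directories
            else:
                path.unlink(missing_ok=True)  # instrumented test files; no error if already gone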