22
33import ast
44import os
5- import shutil
65import tempfile
76import time
87from collections import defaultdict
1817from codeflash .cli_cmds .console import console , logger , progress_bar
1918from codeflash .code_utils import env_utils
2019from codeflash .code_utils .code_replacer import normalize_code , normalize_node
21- from codeflash .code_utils .code_utils import get_run_tmp_file
20+ from codeflash .code_utils .code_utils import cleanup_paths , get_run_tmp_file
2221from codeflash .code_utils .static_analysis import analyze_imported_modules , get_first_top_level_function_or_method_ast
2322from codeflash .discovery .discover_unit_tests import discover_unit_tests
2423from codeflash .discovery .functions_to_optimize import get_functions_to_optimize
@@ -52,6 +51,11 @@ def __init__(self, args: Namespace) -> None:
5251 self .experiment_id = os .getenv ("CODEFLASH_EXPERIMENT_ID" , None )
5352 self .local_aiservice_client = LocalAiServiceClient () if self .experiment_id else None
5453 self .replay_tests_dir = None
54+
55+ self .test_cfg .concolic_test_root_dir = Path (
56+ tempfile .mkdtemp (dir = self .args .tests_root , prefix = "codeflash_concolic_" )
57+ )
58+
5559 def create_function_optimizer (
5660 self ,
5761 function_to_optimize : FunctionToOptimize ,
@@ -71,7 +75,7 @@ def create_function_optimizer(
7175 args = self .args ,
7276 function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else None ,
7377 total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else None ,
74- replay_tests_dir = self .replay_tests_dir
78+ replay_tests_dir = self .replay_tests_dir ,
7579 )
7680
7781 def run (self ) -> None :
@@ -81,6 +85,7 @@ def run(self) -> None:
8185 if not env_utils .ensure_codeflash_api_key ():
8286 return
8387 function_optimizer = None
88+ trace_file = None
8489 file_to_funcs_to_optimize : dict [Path , list [FunctionToOptimize ]]
8590 num_optimizable_functions : int
8691
@@ -98,10 +103,7 @@ def run(self) -> None:
98103 function_benchmark_timings : dict [str , dict [BenchmarkKey , int ]] = {}
99104 total_benchmark_timings : dict [BenchmarkKey , int ] = {}
100105 if self .args .benchmark and num_optimizable_functions > 0 :
101- with progress_bar (
102- f"Running benchmarks in { self .args .benchmarks_root } " ,
103- transient = True ,
104- ):
106+ with progress_bar (f"Running benchmarks in { self .args .benchmarks_root } " , transient = True ):
105107 # Insert decorator
106108 file_path_to_source_code = defaultdict (str )
107109 for file in file_to_funcs_to_optimize :
@@ -113,15 +115,23 @@ def run(self) -> None:
113115 if trace_file .exists ():
114116 trace_file .unlink ()
115117
116- self .replay_tests_dir = Path (tempfile .mkdtemp (prefix = "codeflash_replay_tests_" , dir = self .args .benchmarks_root ))
117- trace_benchmarks_pytest (self .args .benchmarks_root , self .args .tests_root , self .args .project_root , trace_file ) # Run all tests that use pytest-benchmark
118+ self .replay_tests_dir = Path (
119+ tempfile .mkdtemp (prefix = "codeflash_replay_tests_" , dir = self .args .benchmarks_root )
120+ )
121+ trace_benchmarks_pytest (
122+ self .args .benchmarks_root , self .args .tests_root , self .args .project_root , trace_file
123+ ) # Run all tests that use pytest-benchmark
118124 replay_count = generate_replay_test (trace_file , self .replay_tests_dir )
119125 if replay_count == 0 :
120- logger .info (f"No valid benchmarks found in { self .args .benchmarks_root } for functions to optimize, continuing optimization" )
126+ logger .info (
127+ f"No valid benchmarks found in { self .args .benchmarks_root } for functions to optimize, continuing optimization"
128+ )
121129 else :
122130 function_benchmark_timings = CodeFlashBenchmarkPlugin .get_function_benchmark_timings (trace_file )
123131 total_benchmark_timings = CodeFlashBenchmarkPlugin .get_benchmark_timings (trace_file )
124- function_to_results = validate_and_format_benchmark_table (function_benchmark_timings , total_benchmark_timings )
132+ function_to_results = validate_and_format_benchmark_table (
133+ function_benchmark_timings , total_benchmark_timings
134+ )
125135 print_benchmark_table (function_to_results )
126136 except Exception as e :
127137 logger .info (f"Error while tracing existing benchmarks: { e } " )
@@ -131,12 +141,9 @@ def run(self) -> None:
131141 for file in file_path_to_source_code :
132142 with file .open ("w" , encoding = "utf8" ) as f :
133143 f .write (file_path_to_source_code [file ])
144+ self .cleanup ()
134145 optimizations_found : int = 0
135146 function_iterator_count : int = 0
136- if self .args .test_framework == "pytest" :
137- self .test_cfg .concolic_test_root_dir = Path (
138- tempfile .mkdtemp (dir = self .args .tests_root , prefix = "codeflash_concolic_" )
139- )
140147 try :
141148 ph ("cli-optimize-functions-to-optimize" , {"num_functions" : num_optimizable_functions })
142149 if num_optimizable_functions == 0 :
@@ -148,11 +155,12 @@ def run(self) -> None:
148155 function_to_tests : dict [str , list [FunctionCalledInTest ]] = discover_unit_tests (self .test_cfg )
149156 num_discovered_tests : int = sum ([len (value ) for value in function_to_tests .values ()])
150157 console .rule ()
151- logger .info (f"Discovered { num_discovered_tests } existing unit tests in { (time .time () - start_time ):.1f} s at { self .test_cfg .tests_root } " )
158+ logger .info (
159+ f"Discovered { num_discovered_tests } existing unit tests in { (time .time () - start_time ):.1f} s at { self .test_cfg .tests_root } "
160+ )
152161 console .rule ()
153162 ph ("cli-optimize-discovered-tests" , {"num_tests" : num_discovered_tests })
154163
155-
156164 for original_module_path in file_to_funcs_to_optimize :
157165 logger .info (f"Examining file { original_module_path !s} …" )
158166 console .rule ()
@@ -212,14 +220,26 @@ def run(self) -> None:
212220 qualified_name_w_module = function_to_optimize .qualified_name_with_modules_from_root (
213221 self .args .project_root
214222 )
215- if self .args .benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings :
223+ if (
224+ self .args .benchmark
225+ and function_benchmark_timings
226+ and qualified_name_w_module in function_benchmark_timings
227+ and total_benchmark_timings
228+ ):
216229 function_optimizer = self .create_function_optimizer (
217- function_to_optimize , function_to_optimize_ast , function_to_tests , validated_original_code [original_module_path ].source_code , function_benchmark_timings [qualified_name_w_module ], total_benchmark_timings
230+ function_to_optimize ,
231+ function_to_optimize_ast ,
232+ function_to_tests ,
233+ validated_original_code [original_module_path ].source_code ,
234+ function_benchmark_timings [qualified_name_w_module ],
235+ total_benchmark_timings ,
218236 )
219237 else :
220238 function_optimizer = self .create_function_optimizer (
221- function_to_optimize , function_to_optimize_ast , function_to_tests ,
222- validated_original_code [original_module_path ].source_code
239+ function_to_optimize ,
240+ function_to_optimize_ast ,
241+ function_to_tests ,
242+ validated_original_code [original_module_path ].source_code ,
223243 )
224244
225245 best_optimization = function_optimizer .optimize_function ()
@@ -235,23 +255,44 @@ def run(self) -> None:
235255 elif self .args .all :
236256 logger .info ("✨ All functions have been optimized! ✨" )
237257 finally :
238- if function_optimizer :
239- for test_file in function_optimizer .test_files .get_by_type (TestType .GENERATED_REGRESSION ).test_files :
240- test_file .instrumented_behavior_file_path .unlink (missing_ok = True )
241- test_file .benchmarking_file_path .unlink (missing_ok = True )
242- for test_file in function_optimizer .test_files .get_by_type (TestType .EXISTING_UNIT_TEST ).test_files :
243- test_file .instrumented_behavior_file_path .unlink (missing_ok = True )
244- test_file .benchmarking_file_path .unlink (missing_ok = True )
245- for test_file in function_optimizer .test_files .get_by_type (TestType .CONCOLIC_COVERAGE_TEST ).test_files :
246- test_file .instrumented_behavior_file_path .unlink (missing_ok = True )
247- if function_optimizer .test_cfg .concolic_test_root_dir :
248- shutil .rmtree (function_optimizer .test_cfg .concolic_test_root_dir , ignore_errors = True )
249- if self .args .benchmark :
250- if self .replay_tests_dir .exists ():
251- shutil .rmtree (self .replay_tests_dir , ignore_errors = True )
252- trace_file .unlink (missing_ok = True )
253- if hasattr (get_run_tmp_file , "tmpdir" ):
254- get_run_tmp_file .tmpdir .cleanup ()
258+ self .cleanup (function_optimizer = function_optimizer )
259+
def cleanup(self, function_optimizer: FunctionOptimizer | None = None) -> None:
    """Remove the temporary artifacts produced while optimizing.

    Gathers the per-test files recorded on ``function_optimizer`` (when one is
    given) together with the replay-test and concolic-test directories that
    currently exist, passes them all to ``cleanup_paths``, and finally disposes
    of the run-scoped temporary directory cached on ``get_run_tmp_file``.

    ``run`` invokes this both part-way through and in its ``finally`` clause,
    so the method is written to be safe to call repeatedly.

    Args:
        function_optimizer: Optimizer whose registered test files should be
            removed. When ``None`` (or falsy) only the directories and the
            run-scoped temporary directory are cleaned up.
    """
    paths_to_cleanup: list[Path] = []

    if function_optimizer:
        # Which generated file attributes are removed for each kind of test.
        # Order mirrors the original sequence of removals.
        attributes_by_test_type = (
            (TestType.GENERATED_REGRESSION, ("instrumented_behavior_file_path", "benchmarking_file_path")),
            (TestType.EXISTING_UNIT_TEST, ("instrumented_behavior_file_path", "benchmarking_file_path")),
            (TestType.CONCOLIC_COVERAGE_TEST, ("instrumented_behavior_file_path",)),
            (TestType.REPLAY_TEST, ("benchmarking_file_path",)),
        )
        for test_type, attribute_names in attributes_by_test_type:
            test_files = function_optimizer.test_files.get_by_type(test_type).test_files
            paths_to_cleanup.extend(
                getattr(test_file, attribute_name)
                for attribute_name in attribute_names
                for test_file in test_files
            )

    # A set de-duplicates the two directories should they ever coincide;
    # unset (None) or already-removed directories are skipped.
    paths_to_cleanup.extend(
        path for path in {self.replay_tests_dir, self.test_cfg.concolic_test_root_dir} if path and path.exists()
    )

    cleanup_paths(paths_to_cleanup)

    if hasattr(get_run_tmp_file, "tmpdir"):
        get_run_tmp_file.tmpdir.cleanup()
        # Drop the handle to the now-deleted directory so that later code does
        # not keep resolving paths inside it. NOTE(review): this relies on
        # get_run_tmp_file recreating ``tmpdir`` when the attribute is absent,
        # which the ``hasattr`` guard above suggests - confirm in code_utils.
        del get_run_tmp_file.tmpdir
255296
256297
257298def run_with_args (args : Namespace ) -> None :
0 commit comments