
Commit 8193ea6

do proper cleanup
1 parent cb64c68 commit 8193ea6

2 files changed: +82 additions, −39 deletions

codeflash/code_utils/code_utils.py

Lines changed: 5 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 import ast
 import os
+import shutil
 import site
 from functools import lru_cache
 from pathlib import Path
@@ -118,4 +119,7 @@ def has_any_async_functions(code: str) -> bool:
 
 def cleanup_paths(paths: list[Path]) -> None:
     for path in paths:
-        path.unlink(missing_ok=True)
+        if path.is_dir():
+            shutil.rmtree(path, ignore_errors=True)
+        else:
+            path.unlink(missing_ok=True)
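
The new branch lets callers hand cleanup_paths a mix of files and directories without special-casing either. A minimal usage sketch, assuming the codeflash package from this repository is importable; the scratch paths below are illustrative only:

    import tempfile
    from pathlib import Path

    from codeflash.code_utils.code_utils import cleanup_paths

    scratch_dir = Path(tempfile.mkdtemp(prefix="codeflash_example_"))  # a directory
    scratch_file = scratch_dir / "leftover.txt"
    scratch_file.write_text("temporary artifact")

    # Previously cleanup_paths only called unlink(), so a directory would typically
    # raise (IsADirectoryError on POSIX). Now directories go through
    # shutil.rmtree(ignore_errors=True) and files through unlink(missing_ok=True).
    cleanup_paths([scratch_file, scratch_dir])
    cleanup_paths([scratch_file, scratch_dir])  # safe to repeat: missing paths are ignored

    assert not scratch_file.exists() and not scratch_dir.exists()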

codeflash/optimization/optimizer.py

Lines changed: 77 additions & 38 deletions
@@ -2,7 +2,6 @@
 
 import ast
 import os
-import shutil
 import tempfile
 import time
 from collections import defaultdict
@@ -18,7 +17,7 @@
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
-from codeflash.code_utils.code_utils import get_run_tmp_file
+from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
 from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
@@ -52,6 +51,11 @@ def __init__(self, args: Namespace) -> None:
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
         self.replay_tests_dir = None
+        if self.args.test_framework == "pytest":
+            self.test_cfg.concolic_test_root_dir = Path(
+                tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
+            )
+
     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
@@ -71,7 +75,7 @@ def create_function_optimizer(
             args=self.args,
             function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
             total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
-            replay_tests_dir = self.replay_tests_dir
+            replay_tests_dir=self.replay_tests_dir,
         )
 
     def run(self) -> None:
@@ -81,6 +85,7 @@ def run(self) -> None:
         if not env_utils.ensure_codeflash_api_key():
            return
         function_optimizer = None
+        trace_file = None
         file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
         num_optimizable_functions: int
 
@@ -98,10 +103,7 @@ def run(self) -> None:
         function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
         total_benchmark_timings: dict[BenchmarkKey, int] = {}
         if self.args.benchmark and num_optimizable_functions > 0:
-            with progress_bar(
-                f"Running benchmarks in {self.args.benchmarks_root}",
-                transient=True,
-            ):
+            with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
                 # Insert decorator
                 file_path_to_source_code = defaultdict(str)
                 for file in file_to_funcs_to_optimize:
@@ -113,30 +115,35 @@ def run(self) -> None:
                     if trace_file.exists():
                         trace_file.unlink()
 
-                    self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
-                    trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark
+                    self.replay_tests_dir = Path(
+                        tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
+                    )
+                    trace_benchmarks_pytest(
+                        self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
+                    )  # Run all tests that use pytest-benchmark
                     replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
                     if replay_count == 0:
-                        logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
+                        logger.info(
+                            f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
+                        )
                     else:
                         function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
                         total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
-                        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                        function_to_results = validate_and_format_benchmark_table(
+                            function_benchmark_timings, total_benchmark_timings
+                        )
                         print_benchmark_table(function_to_results)
                 except Exception as e:
                     logger.info(f"Error while tracing existing benchmarks: {e}")
                     logger.info("Information on existing benchmarks will not be available for this run.")
+                    self.cleanup(function_optimizer=None)
                 finally:
                     # Restore original source code
                     for file in file_path_to_source_code:
                         with file.open("w", encoding="utf8") as f:
                             f.write(file_path_to_source_code[file])
         optimizations_found: int = 0
         function_iterator_count: int = 0
-        if self.args.test_framework == "pytest":
-            self.test_cfg.concolic_test_root_dir = Path(
-                tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
-            )
         try:
             ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
             if num_optimizable_functions == 0:
@@ -148,11 +155,12 @@ def run(self) -> None:
             function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
             num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
             console.rule()
-            logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}")
+            logger.info(
+                f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
+            )
             console.rule()
             ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
 
-
             for original_module_path in file_to_funcs_to_optimize:
                 logger.info(f"Examining file {original_module_path!s}…")
                 console.rule()
@@ -212,14 +220,26 @@ def run(self) -> None:
                     qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
                         self.args.project_root
                     )
-                    if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
+                    if (
+                        self.args.benchmark
+                        and function_benchmark_timings
+                        and qualified_name_w_module in function_benchmark_timings
+                        and total_benchmark_timings
+                    ):
                        function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
+                            function_benchmark_timings[qualified_name_w_module],
+                            total_benchmark_timings,
                        )
                     else:
                        function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests,
-                            validated_original_code[original_module_path].source_code
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
                        )
 
                     best_optimization = function_optimizer.optimize_function()
@@ -235,23 +255,42 @@ def run(self) -> None:
             elif self.args.all:
                 logger.info("✨ All functions have been optimized! ✨")
         finally:
-            if function_optimizer:
-                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                    test_file.benchmarking_file_path.unlink(missing_ok=True)
-                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                    test_file.benchmarking_file_path.unlink(missing_ok=True)
-                for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files:
-                    test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
-                if function_optimizer.test_cfg.concolic_test_root_dir:
-                    shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True)
-            if self.args.benchmark:
-                if self.replay_tests_dir.exists():
-                    shutil.rmtree(self.replay_tests_dir, ignore_errors=True)
-                trace_file.unlink(missing_ok=True)
-            if hasattr(get_run_tmp_file, "tmpdir"):
-                get_run_tmp_file.tmpdir.cleanup()
+            self.cleanup(function_optimizer=function_optimizer)
+
+    def cleanup(self, function_optimizer: FunctionOptimizer | None) -> None:
+        paths_to_cleanup: list[Path] = []
+        if function_optimizer:
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.benchmarking_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.benchmarking_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files
+            )
+            paths_to_cleanup.extend(
+                test_file.instrumented_behavior_file_path
+                for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files
+            )
+
+        if self.args.benchmark and self.replay_tests_dir and self.replay_tests_dir.exists():
+            paths_to_cleanup.append(self.replay_tests_dir)
+
+        if self.test_cfg.concolic_test_root_dir and self.test_cfg.concolic_test_root_dir.exists():
+            paths_to_cleanup.append(self.test_cfg.concolic_test_root_dir)
+
+        cleanup_paths(paths_to_cleanup)
+
+        if hasattr(get_run_tmp_file, "tmpdir"):
+            get_run_tmp_file.tmpdir.cleanup()
 
 
 def run_with_args(args: Namespace) -> None:
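
With this commit, run() no longer performs ad-hoc unlink/rmtree calls in its finally block: every temporary artifact is collected into one list inside Optimizer.cleanup() and handed to cleanup_paths, and the benchmark-tracing except branch invokes the same method with function_optimizer=None. Below is a self-contained sketch of that "collect paths, then delete once" shape; ScratchSession and its attributes are illustrative stand-ins rather than codeflash APIs, and cleanup_paths is redeclared locally so the sketch runs on its own:

    import shutil
    import tempfile
    from pathlib import Path


    def cleanup_paths(paths: list[Path]) -> None:
        # Mirrors codeflash.code_utils.code_utils.cleanup_paths after this commit.
        for path in paths:
            if path.is_dir():
                shutil.rmtree(path, ignore_errors=True)
            else:
                path.unlink(missing_ok=True)


    class ScratchSession:
        """Illustrative stand-in for the Optimizer's temporary artifacts."""

        def __init__(self) -> None:
            self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="example_replay_"))
            self.concolic_test_root_dir = Path(tempfile.mkdtemp(prefix="example_concolic_"))
            self.instrumented_file = self.replay_tests_dir / "test_instrumented.py"
            self.instrumented_file.write_text("# generated test\n")

        def cleanup(self) -> None:
            # Gather files and directories into a single list and delete them in one
            # pass, the same shape Optimizer.cleanup() now uses.
            paths_to_cleanup: list[Path] = [self.instrumented_file]
            if self.replay_tests_dir.exists():
                paths_to_cleanup.append(self.replay_tests_dir)
            if self.concolic_test_root_dir.exists():
                paths_to_cleanup.append(self.concolic_test_root_dir)
            cleanup_paths(paths_to_cleanup)


    session = ScratchSession()
    try:
        pass  # ... work that may raise goes here ...
    finally:
        session.cleanup()

    assert not session.replay_tests_dir.exists()
    assert not session.concolic_test_root_dir.exists()

One consequence of the new structure: because the concolic test directory is now created in __init__ and the except branch calls self.cleanup(function_optimizer=None), the replay-tests and concolic directories are removed if they already exist when benchmark tracing fails, instead of being left behind.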
