Skip to content

Commit 1311198

Browse files
committed
extract benchmark runs
1 parent 3006d62 commit 1311198

File tree

2 files changed

+65
-53
lines changed

2 files changed

+65
-53
lines changed

codeflash/optimization/optimizer.py

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,66 @@ def __init__(self, args: Namespace) -> None:
4747
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
4848
self.current_function_optimizer: FunctionOptimizer | None = None
4949

50+
def run_benchmarks(
51+
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
52+
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
53+
"""Run benchmarks for the functions to optimize and collect timing information."""
54+
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] = {}
55+
total_benchmark_timings: dict[BenchmarkKey, float] = {}
56+
57+
if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0):
58+
return function_benchmark_timings, total_benchmark_timings
59+
60+
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
61+
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
62+
from codeflash.benchmarking.replay_test import generate_replay_test
63+
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
64+
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
65+
from codeflash.code_utils.env_utils import get_pr_number
66+
67+
with progress_bar(
68+
f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
69+
):
70+
# Insert decorator
71+
file_path_to_source_code = defaultdict(str)
72+
for file in file_to_funcs_to_optimize:
73+
with file.open("r", encoding="utf8") as f:
74+
file_path_to_source_code[file] = f.read()
75+
try:
76+
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
77+
trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
78+
if trace_file.exists():
79+
trace_file.unlink()
80+
81+
self.replay_tests_dir = Path(
82+
tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
83+
)
84+
trace_benchmarks_pytest(
85+
self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
86+
) # Run all tests that use pytest-benchmark
87+
replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
88+
if replay_count == 0:
89+
logger.info(
90+
f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
91+
)
92+
else:
93+
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
94+
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
95+
function_to_results = validate_and_format_benchmark_table(
96+
function_benchmark_timings, total_benchmark_timings
97+
)
98+
print_benchmark_table(function_to_results)
99+
except Exception as e:
100+
logger.info(f"Error while tracing existing benchmarks: {e}")
101+
logger.info("Information on existing benchmarks will not be available for this run.")
102+
finally:
103+
# Restore original source code
104+
for file in file_path_to_source_code:
105+
with file.open("w", encoding="utf8") as f:
106+
f.write(file_path_to_source_code[file])
107+
108+
return function_benchmark_timings, total_benchmark_timings
109+
50110
def create_function_optimizer(
51111
self,
52112
function_to_optimize: FunctionToOptimize,
@@ -108,58 +168,9 @@ def run(self) -> None:
108168
module_root=self.args.module_root,
109169
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
110170
)
111-
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
112-
total_benchmark_timings: dict[BenchmarkKey, int] = {}
113-
if self.args.benchmark and num_optimizable_functions > 0:
114-
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
115-
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
116-
from codeflash.benchmarking.replay_test import generate_replay_test
117-
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
118-
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
119-
120-
console.rule()
121-
with progress_bar(
122-
f"Running benchmarks in {self.args.benchmarks_root}",
123-
transient=True,
124-
revert_to_print=bool(get_pr_number()),
125-
):
126-
# Insert decorator
127-
file_path_to_source_code = defaultdict(str)
128-
for file in file_to_funcs_to_optimize:
129-
with file.open("r", encoding="utf8") as f:
130-
file_path_to_source_code[file] = f.read()
131-
try:
132-
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
133-
trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
134-
if trace_file.exists():
135-
trace_file.unlink()
136-
137-
self.replay_tests_dir = Path(
138-
tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.tests_root)
139-
)
140-
trace_benchmarks_pytest(
141-
self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
142-
) # Run all tests that use pytest-benchmark
143-
replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
144-
if replay_count == 0:
145-
logger.info(
146-
f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
147-
)
148-
else:
149-
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
150-
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
151-
function_to_results = validate_and_format_benchmark_table(
152-
function_benchmark_timings, total_benchmark_timings
153-
)
154-
print_benchmark_table(function_to_results)
155-
except Exception as e:
156-
logger.info(f"Error while tracing existing benchmarks: {e}")
157-
logger.info("Information on existing benchmarks will not be available for this run.")
158-
finally:
159-
# Restore original source code
160-
for file in file_path_to_source_code:
161-
with file.open("w", encoding="utf8") as f:
162-
f.write(file_path_to_source_code[file])
171+
function_benchmark_timings, total_benchmark_timings = self.run_benchmarks(
172+
file_to_funcs_to_optimize, num_optimizable_functions
173+
)
163174
optimizations_found: int = 0
164175
function_iterator_count: int = 0
165176
if self.args.test_framework == "pytest":

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ ignore = [
237237
"S301",
238238
"D104",
239239
"PERF203",
240-
"LOG015"
240+
"LOG015",
241+
"PLC0415"
241242
]
242243

243244
[tool.ruff.lint.flake8-type-checking]

0 commit comments

Comments
 (0)