
Commit 57b80ec

started implementing group by benchmark
1 parent f34f22f commit 57b80ec

7 files changed: +56 −36 lines

codeflash/benchmarking/plugin/plugin.py

Lines changed: 4 additions & 7 deletions

@@ -101,8 +101,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
         qualified_name = f"{module_name}.{function_name}"

         # Create the benchmark key (file::function::line)
-        benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
-        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func, line_number=benchmark_line)
+        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
         # Initialize the inner dictionary if needed
         if qualified_name not in result:
             result[qualified_name] = {}
@@ -152,8 +151,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
     # Process overhead information
     for row in cursor.fetchall():
         benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
-        benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
-        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func, line_number=benchmark_line)
+        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
         overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0  # Handle NULL sum case

     # Query the benchmark_timings table for total times
@@ -167,8 +165,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
         benchmark_file, benchmark_func, benchmark_line, time_ns = row

         # Create the benchmark key (file::function::line)
-        benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
-        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func, line_number=benchmark_line)
+        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
         # Subtract overhead from total time
         overhead = overhead_by_benchmark.get(benchmark_key, 0)
         result[benchmark_key] = time_ns - overhead
@@ -239,7 +236,7 @@ def test_something(benchmark):
         The return value of the function

         """
-        benchmark_file_name = self.request.node.fspath.basename
+        benchmark_file_name = self.request.node.fspath
         benchmark_function_name = self.request.node.name
         line_number = int(str(sys._getframe(1).f_lineno))  # 1 frame up in the call stack
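
The hunks above drop the line number from the benchmark key: timings are now keyed by a BenchmarkKey carrying only the file and function name instead of the old "file::function::line" string. A minimal sketch of the new keying scheme, assuming BenchmarkKey is a frozen (hashable) dataclass like the one in codeflash/models/models.py; the rows and names below are made up for illustration.

from dataclasses import dataclass

@dataclass(frozen=True)
class BenchmarkKey:  # simplified stand-in for codeflash.models.models.BenchmarkKey
    file_name: str
    function_name: str

    def __str__(self) -> str:
        return f"{self.file_name}::{self.function_name}"

# Hypothetical rows in the shape the sqlite cursor yields: (file, function, line, time_ns).
rows = [
    ("test_algos.py", "test_sort", 12, 1_500_000),
    ("test_algos.py", "test_sort", 12, 1_200_000),
]

timings: dict[BenchmarkKey, int] = {}
for benchmark_file, benchmark_func, _line, time_ns in rows:
    # The line number is still read from the row but no longer becomes part of the key.
    key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
    timings[key] = timings.get(key, 0) + time_ns

print({str(k): v for k, v in timings.items()})  # {'test_algos.py::test_sort': 2700000}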

codeflash/benchmarking/replay_test.py

Lines changed: 2 additions & 3 deletions

@@ -196,12 +196,11 @@ def create_trace_replay_test_code(
     return imports + "\n" + metadata + "\n" + test_template

 def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int:
-    """Generate multiple replay tests from the traced function calls, grouping by benchmark name.
+    """Generate multiple replay tests from the traced function calls, grouped by benchmark.

     Args:
         trace_file_path: Path to the SQLite database file
         output_dir: Directory to write the generated tests (if None, only returns the code)
-        project_root: Root directory of the project for module imports
         test_framework: 'pytest' or 'unittest'
         max_run_count: Maximum number of runs to include per function

@@ -267,7 +266,7 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
         # Write to file if requested
         if output_dir:
             output_file = get_test_file_path(
-                test_dir=Path(output_dir), function_name=f"{benchmark_file_name[5:]}_{benchmark_function_name}", test_type="replay"
+                test_dir=Path(output_dir), function_name=f"{benchmark_file_name}_{benchmark_function_name}", test_type="replay"
             )
             # Write test code to file, parents = true
             output_dir.mkdir(parents=True, exist_ok=True)
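
With the plugin now reporting the benchmark file name differently, the [5:] slice on benchmark_file_name is gone and the name is used as-is when building the replay test file. A small sketch of that naming step; get_test_file_path here is a simplified stand-in for the real helper, and the names and directory are hypothetical.

from pathlib import Path

def get_test_file_path(test_dir: Path, function_name: str, test_type: str) -> Path:
    # Hypothetical naming scheme, not the real helper's implementation.
    return test_dir / f"test_{function_name}__{test_type}.py"

benchmark_file_name = "test_algos"       # hypothetical benchmark module name
benchmark_function_name = "test_sort"    # hypothetical benchmark function
output_dir = Path("tests/replay")

# After this commit the benchmark file name is no longer sliced with [5:].
output_file = get_test_file_path(
    test_dir=output_dir,
    function_name=f"{benchmark_file_name}_{benchmark_function_name}",
    test_type="replay",
)
print(output_file)  # tests/replay/test_test_algos_test_sort__replay.py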

codeflash/benchmarking/utils.py

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ def process_benchmark_data(

     for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items():
         try:
-            benchmark_file_name, benchmark_test_function, line_number = benchmark_key.split("::")
+            benchmark_file_name, benchmark_test_function = benchmark_key.split("::")
         except ValueError:
             continue  # Skip malformed benchmark keys
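
process_benchmark_data now unpacks two fields from the key instead of three. A short sketch of that parsing, assuming the keys are strings of the "file::function" form produced by str(BenchmarkKey); the sample keys and timings are made up.

# Hypothetical timings keyed by "file::function" strings.
fto_benchmark_timings = {
    "test_algos.py::test_sort": 2_700_000,
    "malformed-key-without-separator": 100,
}

for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items():
    try:
        benchmark_file_name, benchmark_test_function = benchmark_key.split("::")
    except ValueError:
        continue  # skip malformed benchmark keys, as the real loop does
    print(benchmark_file_name, benchmark_test_function, og_benchmark_timing)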

codeflash/models/models.py

Lines changed: 2 additions & 3 deletions

@@ -82,10 +82,9 @@ class BestOptimization(BaseModel):
 class BenchmarkKey:
     file_name: str
     function_name: str
-    line_number: int

     def __str__(self) -> str:
-        return f"{self.file_name}::{self.function_name}::{self.line_number}"
+        return f"{self.file_name}::{self.function_name}"

 @dataclass
 class BenchmarkDetail:
@@ -270,7 +269,7 @@ class FunctionParent:
 class OriginalCodeBaseline(BaseModel):
     behavioral_test_results: TestResults
     benchmarking_test_results: TestResults
-    replay_benchmarking_test_results: Optional[TestResults] = None
+    replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
     runtime: int
     coverage_results: Optional[CoverageData]
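
The baseline model now stores replay results per benchmark: a dict keyed by BenchmarkKey rather than one flat TestResults. A sketch of that shape with simplified stand-ins for both classes (the real ones are the models in this file and in test_results.py).

from dataclasses import dataclass, field

@dataclass(frozen=True)
class BenchmarkKey:  # simplified stand-in
    file_name: str
    function_name: str

@dataclass
class TestResults:  # simplified stand-in exposing only total_passed_runtime
    runtimes_ns: list[int] = field(default_factory=list)

    def total_passed_runtime(self) -> int:
        return sum(self.runtimes_ns)

# Hypothetical per-benchmark replay results, the new shape of
# OriginalCodeBaseline.replay_benchmarking_test_results.
replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] = {
    BenchmarkKey("test_algos.py", "test_sort"): TestResults([1_200_000, 1_500_000]),
    BenchmarkKey("test_io.py", "test_read"): TestResults([3_000_000]),
}

for key, results in replay_benchmarking_test_results.items():
    print(f"{key.file_name}::{key.function_name}", results.total_passed_runtime())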

codeflash/optimization/function_optimizer.py

Lines changed: 22 additions & 18 deletions

@@ -88,8 +88,8 @@ def __init__(
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         aiservice_client: AiServiceClient | None = None,
-        function_benchmark_timings: dict[str, int] | None = None,
-        total_benchmark_timings: dict[str, int] | None = None,
+        function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
+        total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
@@ -428,20 +428,24 @@ def determine_best_candidate(
             tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
             tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
             if self.args.benchmark:
-                original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results.total_passed_runtime()
-                candidate_replay_runtime = candidate_result.replay_benchmarking_test_results.total_passed_runtime()
-                replay_perf_gain = performance_gain(
-                    original_runtime_ns=original_code_replay_runtime,
-                    optimized_runtime_ns=candidate_replay_runtime,
-                )
-                tree.add(f"Original benchmark replay runtime: {humanize_runtime(original_code_replay_runtime)}")
-                tree.add(
-                    f"Best benchmark replay runtime: {humanize_runtime(candidate_replay_runtime)} "
-                    f"(measured over {candidate_result.max_loop_count} "
-                    f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
-                )
-                tree.add(f"Speedup percentage for benchmark replay test: {replay_perf_gain * 100:.1f}%")
-                tree.add(f"Speedup ratio for benchmark replay test: {replay_perf_gain + 1:.1f}X")
+
+                benchmark_keys = {(benchmark.file_name, benchmark.function_name) for benchmark in self.total_benchmark_timings}
+                test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmark(benchmark_keys)
+                for benchmark_key, test_results in test_results_by_benchmark.items():
+                    original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                    candidate_replay_runtime = candidate_result.replay_benchmarking_test_results.total_passed_runtime()
+                    replay_perf_gain = performance_gain(
+                        original_runtime_ns=original_code_replay_runtime,
+                        optimized_runtime_ns=candidate_replay_runtime,
+                    )
+                    tree.add(f"Original benchmark replay runtime: {humanize_runtime(original_code_replay_runtime)}")
+                    tree.add(
+                        f"Best benchmark replay runtime: {humanize_runtime(candidate_replay_runtime)} "
+                        f"(measured over {candidate_result.max_loop_count} "
+                        f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
+                    )
+                    tree.add(f"Speedup percentage for benchmark replay test: {replay_perf_gain * 100:.1f}%")
+                    tree.add(f"Speedup ratio for benchmark replay test: {replay_perf_gain + 1:.1f}X")
             best_optimization = BestOptimization(
                 candidate=candidate,
                 helper_functions=code_context.helper_functions,
@@ -898,7 +902,7 @@ def establish_original_code_baseline(
         logger.debug(f"Total original code runtime (ns): {total_timing}")

         if self.args.benchmark:
-            replay_benchmarking_test_results = benchmarking_results.filter(TestType.REPLAY_TEST)
+            replay_benchmarking_test_results = benchmarking_results.filter_by_test_type(TestType.REPLAY_TEST)
             logger.info(f"Total replay test runtime: {humanize_runtime(replay_benchmarking_test_results.total_passed_runtime())}")
         return Success(
             (
@@ -1020,7 +1024,7 @@ def run_optimized_candidate(

         logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
         if self.args.benchmark:
-            candidate_replay_benchmarking_results = candidate_benchmarking_results.filter(TestType.REPLAY_TEST)
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.filter_by_test_type(TestType.REPLAY_TEST)
             logger.debug(
                 f"Total optimized code {optimization_candidate_index} replay benchmark runtime (ns): {candidate_replay_benchmarking_results.total_passed_runtime()}"
             )
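
The reporting block now loops over benchmarks instead of summing all replay tests together, comparing original and candidate replay runtimes per benchmark. A stripped-down sketch of that per-benchmark comparison; performance_gain here is a plausible stand-in (relative speedup), and the runtimes are invented.

def performance_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
    # Stand-in: fractional speedup of the optimized code over the original.
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

# Hypothetical per-benchmark replay runtimes in nanoseconds, keyed by (file, function).
original_replay = {("test_algos.py", "test_sort"): 2_700_000, ("test_io.py", "test_read"): 3_000_000}
candidate_replay = {("test_algos.py", "test_sort"): 1_800_000, ("test_io.py", "test_read"): 2_500_000}

for benchmark_key, candidate_runtime in candidate_replay.items():
    original_runtime = original_replay[benchmark_key]
    replay_perf_gain = performance_gain(
        original_runtime_ns=original_runtime,
        optimized_runtime_ns=candidate_runtime,
    )
    file_name, function_name = benchmark_key
    print(f"{file_name}::{function_name}: speedup {replay_perf_gain * 100:.1f}% ({replay_perf_gain + 1:.1f}X)")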

codeflash/optimization/optimizer.py

Lines changed: 6 additions & 3 deletions

@@ -60,8 +60,8 @@ def create_function_optimizer(
         function_to_optimize_ast: ast.FunctionDef | None = None,
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_source_code: str | None = "",
-        function_benchmark_timings: dict[str, dict[str, float]] | None = None,
-        total_benchmark_timings: dict[str, float] | None = None,
+        function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
+        total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
     ) -> FunctionOptimizer:
         return FunctionOptimizer(
             function_to_optimize=function_to_optimize,
@@ -111,7 +111,10 @@ def run(self) -> None:
             try:
                 instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
                 trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
-                replay_tests_dir = Path(self.args.tests_root) / "codeflash_replay_tests"
+                if trace_file.exists():
+                    trace_file.unlink()
+
+                replay_tests_dir = Path(self.args.tests_root)
                 trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Run all tests that use pytest-benchmark
                 replay_count = generate_replay_test(trace_file, replay_tests_dir)
                 if replay_count == 0:
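
run() now deletes any leftover benchmarks.trace before tracing and writes replay tests directly into tests_root. A small sketch of that cleanup step, with hypothetical paths standing in for self.args.

from pathlib import Path

# Hypothetical stand-ins for self.args.benchmarks_root / self.args.tests_root.
benchmarks_root = Path("benchmarks")
tests_root = Path("tests")

trace_file = benchmarks_root / "benchmarks.trace"
if trace_file.exists():
    trace_file.unlink()  # drop any trace left by a previous run (presumably so stale timings are not mixed in)

replay_tests_dir = tests_root  # replay tests now go straight into tests_root
print(trace_file, replay_tests_dir)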

codeflash/verification/test_results.py

Lines changed: 19 additions & 1 deletion

@@ -125,7 +125,7 @@ def merge(self, other: TestResults) -> None:
                 raise ValueError(msg)
             self.test_result_idx[k] = v + original_len

-    def filter(self, test_type: TestType) -> TestResults:
+    def filter_by_test_type(self, test_type: TestType) -> TestResults:
         filtered_test_results = []
         filtered_test_results_idx = {}
         for test_result in self.test_results:
@@ -134,6 +134,24 @@ def filter(self, test_type: TestType) -> TestResults:
                 filtered_test_results.append(test_result)
         return TestResults(test_results=filtered_test_results, test_result_idx=filtered_test_results_idx)

+    def group_by_benchmark(self, benchmark_key_set: set[tuple[str, str]]) -> dict[tuple[str, str], TestResults]:
+        """Group replay test results by benchmark key.
+
+        For now, a (file_path, function_name) tuple is used as the benchmark key; BenchmarkKey can't be imported here because of a circular import.
+
+        Args:
+            benchmark_key_set (set[tuple[str, str]]): A set of (file_path, function_name) tuples
+
+        Returns:
+            dict[tuple[str, str], TestResults]: The replay test results grouped by benchmark key.
+
+        """
+        test_result_by_benchmark = defaultdict(TestResults)
+        for test_result in self.test_results:
+            if test_result.test_type == TestType.REPLAY_TEST and (test_result.id.test_module_path, test_result.id.test_function_name) in benchmark_key_set:
+                test_result_by_benchmark[(test_result.id.test_module_path, test_result.id.test_function_name)].add(test_result)
+        return test_result_by_benchmark
+
     def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None:
         try:
             return self.test_results[self.test_result_idx[unique_invocation_loop_id]]
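
A usage sketch of the new group_by_benchmark logic, with simplified stand-ins for TestType and the invocation records (the real ones live elsewhere in codeflash), showing how only replay tests whose key appears in the benchmark key set are grouped.

from collections import defaultdict
from dataclasses import dataclass
from enum import Enum

class TestType(Enum):  # simplified stand-in
    REPLAY_TEST = "replay"
    EXISTING_UNIT_TEST = "existing"

@dataclass
class Invocation:  # simplified stand-in for a test invocation record
    test_module_path: str
    test_function_name: str
    test_type: TestType
    runtime_ns: int

invocations = [
    Invocation("tests/test_sort__replay.py", "test_sort", TestType.REPLAY_TEST, 1_200_000),
    Invocation("tests/test_sort__replay.py", "test_sort", TestType.REPLAY_TEST, 1_500_000),
    Invocation("tests/test_other.py", "test_other", TestType.EXISTING_UNIT_TEST, 900_000),
]

benchmark_key_set = {("tests/test_sort__replay.py", "test_sort")}

grouped: dict[tuple[str, str], list[Invocation]] = defaultdict(list)
for inv in invocations:
    key = (inv.test_module_path, inv.test_function_name)
    if inv.test_type == TestType.REPLAY_TEST and key in benchmark_key_set:
        grouped[key].append(inv)

for key, invs in grouped.items():
    print(key, sum(i.runtime_ns for i in invs))  # total replay runtime per benchmark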
