Skip to content

Commit 9be93b6

Browse files
Merge pull request #294 from codeflash-ai/add-timing-info-to-generated-tests
Add the test case timing to the generated test [CF-482]
2 parents 1eb54df + 84415b7 commit 9be93b6

21 files changed

+720
-49
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None:
2525
"""Set up the database connection for direct writing.
2626
2727
Args:
28+
----
2829
trace_path: Path to the trace database file
2930
3031
"""
@@ -52,6 +53,7 @@ def write_function_timings(self) -> None:
5253
"""Write function call data directly to the database.
5354
5455
Args:
56+
----
5557
data: List of function call data tuples to write
5658
5759
"""
@@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable:
9496
"""Use as a decorator to trace function execution.
9597
9698
Args:
99+
----
97100
func: The function to be decorated
98101
99102
Returns:
103+
-------
100104
The wrapped function
101105
102106
"""

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
7676
"""Add codeflash_trace to a function.
7777
7878
Args:
79+
----
7980
code: The source code as a string
8081
functions_to_optimize: List of FunctionToOptimize instances containing function details
8182
8283
Returns:
84+
-------
8385
The modified source code as a string
8486
8587
"""

codeflash/benchmarking/plugin/plugin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
7474
"""Process the trace file and extract timing data for all functions.
7575
7676
Args:
77+
----
7778
trace_path: Path to the trace file
7879
7980
Returns:
81+
-------
8082
A nested dictionary where:
8183
- Outer keys are module_name.qualified_name (module.class.function)
8284
- Inner keys are of type BenchmarkKey
@@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
132134
"""Extract total benchmark timings from trace files.
133135
134136
Args:
137+
----
135138
trace_path: Path to the trace file
136139
137140
Returns:
141+
-------
138142
A dictionary mapping where:
139143
- Keys are of type BenchmarkKey
140144
- Values are total benchmark timing in milliseconds (with overhead subtracted)

codeflash/benchmarking/replay_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def create_trace_replay_test_code(
5555
"""Create a replay test for functions based on trace data.
5656
5757
Args:
58+
----
5859
trace_file: Path to the SQLite database file
5960
functions_data: List of dictionaries with function info extracted from DB
6061
test_framework: 'pytest' or 'unittest'
6162
max_run_count: Maximum number of runs to include in the test
6263
6364
Returns:
65+
-------
6466
A string containing the test code
6567
6668
"""
@@ -218,12 +220,14 @@ def generate_replay_test(
218220
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
219221
220222
Args:
223+
----
221224
trace_file_path: Path to the SQLite database file
222225
output_dir: Directory to write the generated tests (if None, only returns the code)
223226
test_framework: 'pytest' or 'unittest'
224227
max_run_count: Maximum number of runs to include per function
225228
226229
Returns:
230+
-------
227231
Dictionary mapping benchmark names to generated test code
228232
229233
"""

codeflash/benchmarking/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@ def process_benchmark_data(
8383
"""Process benchmark data and generate detailed benchmark information.
8484
8585
Args:
86+
----
8687
replay_performance_gain: The performance gain from replay
8788
fto_benchmark_timings: Function to optimize benchmark timings
8889
total_benchmark_timings: Total benchmark timings
8990
9091
Returns:
92+
-------
9193
ProcessedBenchmarkInfo containing processed benchmark details
9294
9395
"""

codeflash/cli_cmds/logging_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
2727
],
2828
force=True,
2929
)
30-
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
30+
logging.info("Verbose DEBUG logging enabled")
3131
else:
32-
logging.info("Logging level set to INFO") # noqa: LOG015
32+
logging.info("Logging level set to INFO")
3333
console.rule()

codeflash/code_utils/checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def add_function_to_checkpoint(
4747
"""Add a function to the checkpoint after it has been processed.
4848
4949
Args:
50+
----
5051
function_fully_qualified_name: The fully qualified name of the function
5152
status: Status of optimization (e.g., "optimized", "failed", "skipped")
5253
additional_info: Any additional information to store about the function
@@ -104,7 +105,8 @@ def cleanup(self) -> None:
104105
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
105106
"""Get information about all processed functions, regardless of status.
106107
107-
Returns:
108+
Returns
109+
-------
108110
Dictionary mapping function names to their processing information
109111
110112
"""
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import re
2+
3+
import libcst as cst
4+
5+
from codeflash.cli_cmds.console import logger
6+
from codeflash.code_utils.time_utils import format_time
7+
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
8+
9+
10+
def remove_functions_from_generated_tests(
11+
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
12+
) -> GeneratedTestsList:
13+
new_generated_tests = []
14+
for generated_test in generated_tests.generated_tests:
15+
for test_function in test_functions_to_remove:
16+
function_pattern = re.compile(
17+
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
18+
re.DOTALL,
19+
)
20+
21+
match = function_pattern.search(generated_test.generated_original_test_source)
22+
23+
if match is None or "@pytest.mark.parametrize" in match.group(0):
24+
continue
25+
26+
generated_test.generated_original_test_source = function_pattern.sub(
27+
"", generated_test.generated_original_test_source
28+
)
29+
30+
new_generated_tests.append(generated_test)
31+
32+
return GeneratedTestsList(generated_tests=new_generated_tests)
33+
34+
35+
def add_runtime_comments_to_generated_tests(
36+
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
37+
) -> GeneratedTestsList:
38+
"""Add runtime performance comments to function calls in generated tests."""
39+
# Create dictionaries for fast lookup of runtime data
40+
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
41+
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
42+
43+
class RuntimeCommentTransformer(cst.CSTTransformer):
44+
def __init__(self) -> None:
45+
self.in_test_function = False
46+
self.current_test_name: str | None = None
47+
48+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
49+
if node.name.value.startswith("test_"):
50+
self.in_test_function = True
51+
self.current_test_name = node.name.value
52+
else:
53+
self.in_test_function = False
54+
self.current_test_name = None
55+
56+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
57+
if original_node.name.value.startswith("test_"):
58+
self.in_test_function = False
59+
self.current_test_name = None
60+
return updated_node
61+
62+
def leave_SimpleStatementLine(
63+
self,
64+
original_node: cst.SimpleStatementLine, # noqa: ARG002
65+
updated_node: cst.SimpleStatementLine,
66+
) -> cst.SimpleStatementLine:
67+
if not self.in_test_function or not self.current_test_name:
68+
return updated_node
69+
70+
# Look for assignment statements that assign to codeflash_output
71+
# Handle both single statements and multiple statements on one line
72+
codeflash_assignment_found = False
73+
for stmt in updated_node.body:
74+
if isinstance(stmt, cst.Assign) and (
75+
len(stmt.targets) == 1
76+
and isinstance(stmt.targets[0].target, cst.Name)
77+
and stmt.targets[0].target.value == "codeflash_output"
78+
):
79+
codeflash_assignment_found = True
80+
break
81+
82+
if codeflash_assignment_found:
83+
# Find matching test cases by looking for this test function name in the test results
84+
matching_original_times = []
85+
matching_optimized_times = []
86+
87+
for invocation_id, runtimes in original_runtime_by_test.items():
88+
if invocation_id.test_function_name == self.current_test_name:
89+
matching_original_times.extend(runtimes)
90+
91+
for invocation_id, runtimes in optimized_runtime_by_test.items():
92+
if invocation_id.test_function_name == self.current_test_name:
93+
matching_optimized_times.extend(runtimes)
94+
95+
if matching_original_times and matching_optimized_times:
96+
original_time = min(matching_original_times)
97+
optimized_time = min(matching_optimized_times)
98+
99+
# Create the runtime comment
100+
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
101+
102+
# Add comment to the trailing whitespace
103+
new_trailing_whitespace = cst.TrailingWhitespace(
104+
whitespace=cst.SimpleWhitespace(" "),
105+
comment=cst.Comment(comment_text),
106+
newline=updated_node.trailing_whitespace.newline,
107+
)
108+
109+
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
110+
111+
return updated_node
112+
113+
# Process each generated test
114+
modified_tests = []
115+
for test in generated_tests.generated_tests:
116+
try:
117+
# Parse the test source code
118+
tree = cst.parse_module(test.generated_original_test_source)
119+
120+
# Transform the tree to add runtime comments
121+
transformer = RuntimeCommentTransformer()
122+
modified_tree = tree.visit(transformer)
123+
124+
# Convert back to source code
125+
modified_source = modified_tree.code
126+
127+
# Create a new GeneratedTests object with the modified source
128+
modified_test = GeneratedTests(
129+
generated_original_test_source=modified_source,
130+
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
131+
instrumented_perf_test_source=test.instrumented_perf_test_source,
132+
behavior_file_path=test.behavior_file_path,
133+
perf_file_path=test.perf_file_path,
134+
)
135+
modified_tests.append(modified_test)
136+
except Exception as e:
137+
# If parsing fails, keep the original test
138+
logger.debug(f"Failed to add runtime comments to test: {e}")
139+
modified_tests.append(test)
140+
141+
return GeneratedTestsList(generated_tests=modified_tests)

codeflash/code_utils/line_profile_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None:
2424
"""Initialize the transformer.
2525
2626
Args:
27+
----
2728
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
2829
decorator_name: The name of the decorator to add.
2930
@@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str,
144145
"""Add a decorator to a function with the exact qualified name in the source code.
145146
146147
Args:
148+
----
147149
module: The Python source code as a CST module.
148150
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
149151
decorator_name: The name of the decorator to add.
150152
151153
Returns:
154+
-------
152155
The modified CST module.
153156
154157
"""

codeflash/code_utils/remove_generated_tests.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)