Skip to content

Commit 077d4b7

Browse files
Merge branch 'main' into fix/skip-optimization-for-draft-prs
2 parents 9b0732a + 5298a68 commit 077d4b7

File tree

13 files changed

+1087
-132
lines changed

13 files changed

+1087
-132
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 86 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import os
12
import re
3+
from pathlib import Path
24

35
import libcst as cst
46

57
from codeflash.cli_cmds.console import logger
6-
from codeflash.code_utils.time_utils import format_time
7-
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
8+
from codeflash.code_utils.time_utils import format_perf, format_time
9+
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
10+
from codeflash.result.critic import performance_gain
11+
from codeflash.verification.verification_utils import TestConfig
812

913

1014
def remove_functions_from_generated_tests(
@@ -33,40 +37,46 @@ def remove_functions_from_generated_tests(
3337

3438

3539
def add_runtime_comments_to_generated_tests(
36-
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
40+
test_cfg: TestConfig,
41+
generated_tests: GeneratedTestsList,
42+
original_runtimes: dict[InvocationId, list[int]],
43+
optimized_runtimes: dict[InvocationId, list[int]],
3744
) -> GeneratedTestsList:
3845
"""Add runtime performance comments to function calls in generated tests."""
39-
# Create dictionaries for fast lookup of runtime data
40-
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
41-
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
46+
tests_root = test_cfg.tests_root
47+
module_root = test_cfg.project_root_path
48+
rel_tests_root = tests_root.relative_to(module_root)
4249

50+
# TODO: reduce for loops to one
4351
class RuntimeCommentTransformer(cst.CSTTransformer):
44-
def __init__(self) -> None:
45-
self.in_test_function = False
46-
self.current_test_name: str | None = None
52+
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
53+
self.test = test
54+
self.context_stack: list[str] = []
55+
self.tests_root = tests_root
56+
self.rel_tests_root = rel_tests_root
57+
58+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
59+
# Track when we enter a class
60+
self.context_stack.append(node.name.value)
61+
62+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
63+
# Pop the context when we leave a class
64+
self.context_stack.pop()
65+
return updated_node
4766

4867
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
49-
if node.name.value.startswith("test_"):
50-
self.in_test_function = True
51-
self.current_test_name = node.name.value
52-
else:
53-
self.in_test_function = False
54-
self.current_test_name = None
55-
56-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
57-
if original_node.name.value.startswith("test_"):
58-
self.in_test_function = False
59-
self.current_test_name = None
68+
self.context_stack.append(node.name.value)
69+
70+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
71+
# Pop the context when we leave a function
72+
self.context_stack.pop()
6073
return updated_node
6174

6275
def leave_SimpleStatementLine(
6376
self,
6477
original_node: cst.SimpleStatementLine, # noqa: ARG002
6578
updated_node: cst.SimpleStatementLine,
6679
) -> cst.SimpleStatementLine:
67-
if not self.in_test_function or not self.current_test_name:
68-
return updated_node
69-
7080
# Look for assignment statements that assign to codeflash_output
7181
# Handle both single statements and multiple statements on one line
7282
codeflash_assignment_found = False
@@ -83,30 +93,65 @@ def leave_SimpleStatementLine(
8393
# Find matching test cases by looking for this test function name in the test results
8494
matching_original_times = []
8595
matching_optimized_times = []
86-
87-
for invocation_id, runtimes in original_runtime_by_test.items():
88-
if invocation_id.test_function_name == self.current_test_name:
96+
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
97+
for invocation_id, runtimes in original_runtimes.items():
98+
qualified_name = (
99+
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
100+
if invocation_id.test_class_name
101+
else invocation_id.test_function_name
102+
)
103+
rel_path = (
104+
Path(invocation_id.test_module_path.replace(".", os.sep))
105+
.with_suffix(".py")
106+
.relative_to(self.rel_tests_root)
107+
)
108+
if qualified_name == ".".join(self.context_stack) and rel_path in [
109+
self.test.behavior_file_path.relative_to(self.tests_root),
110+
self.test.perf_file_path.relative_to(self.tests_root),
111+
]:
89112
matching_original_times.extend(runtimes)
90113

91-
for invocation_id, runtimes in optimized_runtime_by_test.items():
92-
if invocation_id.test_function_name == self.current_test_name:
114+
for invocation_id, runtimes in optimized_runtimes.items():
115+
qualified_name = (
116+
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
117+
if invocation_id.test_class_name
118+
else invocation_id.test_function_name
119+
)
120+
rel_path = (
121+
Path(invocation_id.test_module_path.replace(".", os.sep))
122+
.with_suffix(".py")
123+
.relative_to(self.rel_tests_root)
124+
)
125+
if qualified_name == ".".join(self.context_stack) and rel_path in [
126+
self.test.behavior_file_path.relative_to(self.tests_root),
127+
self.test.perf_file_path.relative_to(self.tests_root),
128+
]:
93129
matching_optimized_times.extend(runtimes)
94130

95131
if matching_original_times and matching_optimized_times:
96132
original_time = min(matching_original_times)
97133
optimized_time = min(matching_optimized_times)
98-
99-
# Create the runtime comment
100-
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
101-
102-
# Add comment to the trailing whitespace
103-
new_trailing_whitespace = cst.TrailingWhitespace(
104-
whitespace=cst.SimpleWhitespace(" "),
105-
comment=cst.Comment(comment_text),
106-
newline=updated_node.trailing_whitespace.newline,
107-
)
108-
109-
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
134+
if original_time != 0 and optimized_time != 0:
135+
perf_gain = format_perf(
136+
abs(
137+
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
138+
* 100
139+
)
140+
)
141+
status = "slower" if optimized_time > original_time else "faster"
142+
# Create the runtime comment
143+
comment_text = (
144+
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
145+
)
146+
147+
# Add comment to the trailing whitespace
148+
new_trailing_whitespace = cst.TrailingWhitespace(
149+
whitespace=cst.SimpleWhitespace(" "),
150+
comment=cst.Comment(comment_text),
151+
newline=updated_node.trailing_whitespace.newline,
152+
)
153+
154+
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
110155

111156
return updated_node
112157

@@ -118,7 +163,7 @@ def leave_SimpleStatementLine(
118163
tree = cst.parse_module(test.generated_original_test_source)
119164

120165
# Transform the tree to add runtime comments
121-
transformer = RuntimeCommentTransformer()
166+
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
122167
modified_tree = tree.visit(transformer)
123168

124169
# Convert back to source code

codeflash/code_utils/env_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
2222
f.flush()
2323
tmp_file = Path(f.name)
2424
try:
25-
format_code(formatter_cmds, tmp_file, print_status=False)
25+
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
2626
except Exception:
2727
exit_with_message(
2828
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
2929
error_on_exit=True,
3030
)
31-
3231
return return_code
3332

3433

codeflash/code_utils/formatter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def apply_formatter_cmds(
4343
path: Path,
4444
test_dir_str: Optional[str],
4545
print_status: bool, # noqa
46+
exit_on_failure: bool = True, # noqa
4647
) -> tuple[Path, str]:
4748
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
4849
formatter_name = cmds[0].lower()
@@ -84,8 +85,8 @@ def apply_formatter_cmds(
8485
expand=False,
8586
)
8687
console.print(panel)
87-
88-
raise e from None
88+
if exit_on_failure:
89+
raise e from None
8990

9091
return file_path, file_path.read_text(encoding="utf8")
9192

@@ -106,6 +107,7 @@ def format_code(
106107
optimized_function: str = "",
107108
check_diff: bool = False, # noqa
108109
print_status: bool = True, # noqa
110+
exit_on_failure: bool = True, # noqa
109111
) -> str:
110112
with tempfile.TemporaryDirectory() as test_dir_str:
111113
if isinstance(path, str):
@@ -138,7 +140,9 @@ def format_code(
138140
)
139141
return original_code
140142
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
141-
_, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
143+
_, formatted_code = apply_formatter_cmds(
144+
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
145+
)
142146
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
143147
return formatted_code
144148

codeflash/code_utils/time_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,15 @@ def format_time(nanoseconds: int) -> str:
8585

8686
# This should never be reached, but included for completeness
8787
return f"{nanoseconds}ns"
88+
89+
90+
def format_perf(percentage: float) -> str:
91+
"""Format percentage into a human-readable string with 3 significant digits when needed."""
92+
percentage_abs = abs(percentage)
93+
if percentage_abs >= 100:
94+
return f"{percentage:.0f}"
95+
if percentage_abs >= 10:
96+
return f"{percentage:.1f}"
97+
if percentage_abs >= 1:
98+
return f"{percentage:.2f}"
99+
return f"{percentage:.3f}"

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
557557

558558
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
559559
# Efficient single traversal, directly accumulating into a dict.
560+
# can track mins here and only sums can be return in total_passed_runtime
560561
by_id: dict[InvocationId, list[int]] = {}
561562
for result in self.test_results:
562563
if result.did_pass:

codeflash/optimization/function_optimizer.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
341341
optimized_function=best_optimization.candidate.source_code,
342342
)
343343

344-
existing_tests = existing_tests_source_for(
345-
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
346-
function_to_all_tests,
347-
tests_root=self.test_cfg.tests_root,
348-
)
349-
350344
original_code_combined = original_helper_code.copy()
351345
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
352346
new_code_combined = new_helper_code.copy()
@@ -360,15 +354,26 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
360354
generated_tests = remove_functions_from_generated_tests(
361355
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
362356
)
357+
original_runtime_by_test = (
358+
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
359+
)
360+
optimized_runtime_by_test = (
361+
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
362+
)
363363
# Add runtime comments to generated tests before creating the PR
364364
generated_tests = add_runtime_comments_to_generated_tests(
365-
generated_tests,
366-
original_code_baseline.benchmarking_test_results,
367-
best_optimization.winning_benchmarking_test_results,
365+
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
368366
)
369367
generated_tests_str = "\n\n".join(
370368
[test.generated_original_test_source for test in generated_tests.generated_tests]
371369
)
370+
existing_tests = existing_tests_source_for(
371+
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
372+
function_to_all_tests,
373+
test_cfg=self.test_cfg,
374+
original_runtimes_all=original_runtime_by_test,
375+
optimized_runtimes_all=optimized_runtime_by_test,
376+
)
372377
if concolic_test_str:
373378
generated_tests_str += "\n\n" + concolic_test_str
374379

0 commit comments

Comments
 (0)