Skip to content

Commit bdb5ca6

Browse files
committed
Refactor
1 parent e25a185 commit bdb5ca6

File tree

3 files changed

+157
-154
lines changed

3 files changed

+157
-154
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import re
22

3-
from codeflash.models.models import GeneratedTestsList
3+
import libcst as cst
4+
5+
from codeflash.cli_cmds.console import logger
6+
from codeflash.code_utils.time_utils import format_time
7+
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
48

59

610
def remove_functions_from_generated_tests(
@@ -26,3 +30,111 @@ def remove_functions_from_generated_tests(
2630
new_generated_tests.append(generated_test)
2731

2832
return GeneratedTestsList(generated_tests=new_generated_tests)
33+
34+
35+
def add_runtime_comments_to_generated_tests(
    generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
) -> GeneratedTestsList:
    """Add runtime performance comments to ``codeflash_output`` assignments in generated tests.

    Every simple statement line inside a ``test_*`` function that assigns to the
    single name ``codeflash_output`` gets a trailing comment of the form
    ``# <original> -> <optimized>``, where each side is the smallest runtime
    recorded for that test function name, rendered with ``format_time``.

    Args:
        generated_tests: The generated tests whose original source is annotated.
        original_test_results: Benchmark results for the original code.
        optimized_test_results: Benchmark results for the optimized code.

    Returns:
        A new ``GeneratedTestsList``. A test whose source could not be processed
        is passed through unchanged.
    """

    def fastest_runtime_by_test_function(runtime_by_invocation) -> dict:
        """Collapse per-invocation runtimes into the minimum runtime per test function name."""
        fastest = {}
        for invocation_id, runtimes in runtime_by_invocation.items():
            if not runtimes:
                continue
            name = invocation_id.test_function_name
            best = min(runtimes)
            if name not in fastest or best < fastest[name]:
                fastest[name] = best
        return fastest

    # Computed once up front so the transformer does a dict lookup per statement
    # instead of rescanning every invocation id.
    original_fastest = fastest_runtime_by_test_function(original_test_results.usable_runtime_data_by_test_case())
    optimized_fastest = fastest_runtime_by_test_function(optimized_test_results.usable_runtime_data_by_test_case())

    class RuntimeCommentTransformer(cst.CSTTransformer):
        """Append runtime comments to ``codeflash_output`` assignments inside test functions."""

        def __init__(self) -> None:
            super().__init__()
            # One entry per enclosing function definition: the function's name when it
            # is a test, otherwise None. A stack (rather than a single flag) restores
            # the outer test's context after a nested helper function is left.
            self.function_stack = []

        def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
            name = node.name.value
            self.function_stack.append(name if name.startswith("test_") else None)

        def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
            self.function_stack.pop()
            return updated_node

        def leave_SimpleStatementLine(
            self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
        ) -> cst.SimpleStatementLine:
            # Only statements whose innermost enclosing function is a test are annotated.
            current_test_name = self.function_stack[-1] if self.function_stack else None
            if not current_test_name:
                return updated_node

            # A line may hold several statements separated by ';' - one match is enough.
            assigns_codeflash_output = any(
                isinstance(stmt, cst.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0].target, cst.Name)
                and stmt.targets[0].target.value == "codeflash_output"
                for stmt in updated_node.body
            )
            if not assigns_codeflash_output:
                return updated_node

            # Both sides are required; a one-sided comparison would be meaningless.
            if current_test_name not in original_fastest or current_test_name not in optimized_fastest:
                return updated_node

            original_time = original_fastest[current_test_name]
            optimized_time = optimized_fastest[current_test_name]
            comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"

            # The new trailing whitespace replaces the old one; only the newline is
            # carried over, so any comment already on the line is replaced.
            new_trailing_whitespace = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(" "),
                comment=cst.Comment(comment_text),
                newline=updated_node.trailing_whitespace.newline,
            )
            return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

    # Process each generated test
    modified_tests = []
    for test in generated_tests.generated_tests:
        try:
            tree = cst.parse_module(test.generated_original_test_source)
            modified_source = tree.visit(RuntimeCommentTransformer()).code

            # Only the original source changes; the instrumented sources and paths are copied as-is.
            modified_test = GeneratedTests(
                generated_original_test_source=modified_source,
                instrumented_behavior_test_source=test.instrumented_behavior_test_source,
                instrumented_perf_test_source=test.instrumented_perf_test_source,
                behavior_file_path=test.behavior_file_path,
                perf_file_path=test.perf_file_path,
            )
            modified_tests.append(modified_test)
        except Exception as e:
            # Best effort: annotation is cosmetic, so on any failure keep the original test.
            logger.debug(f"Failed to add runtime comments to test: {e}")
            modified_tests.append(test)

    return GeneratedTestsList(generated_tests=modified_tests)

codeflash/code_utils/time_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,42 @@ def humanize_runtime(time_in_ns: int) -> str:
4949
runtime_human = runtime_human_parts[0]
5050

5151
return f"{runtime_human} {units}"
52+
53+
54+
def format_time(nanoseconds: int) -> str:
    """Format a nanosecond duration as a compact string with at most 3 significant digits.

    Durations below 1000 ns are printed as an integer followed by ``ns``.
    Larger durations are scaled to μs, ms or s (the largest unit that keeps the
    scaled value at or above 1). A scaled value of 100 or more is truncated to
    a whole number; a smaller one keeps as many decimals as fit in three
    significant digits after rounding.

    Args:
        nanoseconds: Duration in nanoseconds.

    Returns:
        The formatted duration, e.g. ``"999ns"``, ``"1.50μs"``, ``"12.3ms"`` or ``"123s"``.
    """

    def format_with_precision(value: float, unit: str) -> str:
        """Format ``value`` with 3 significant digits, accounting for rounding."""
        # Decide on the number of decimals from the *rounded* text: 9.999 rounds
        # to 10.00 at two decimals, which would be four significant digits, so
        # it must drop to one decimal ("10.0") instead.
        for decimals, limit in ((2, 10), (1, 100)):
            text = f"{value:.{decimals}f}"
            if float(text) < limit:
                return f"{text}{unit}"
        return f"{value:.0f}{unit}"

    if nanoseconds < 1_000:
        return f"{nanoseconds}ns"

    if nanoseconds < 1_000_000:
        divisor, unit = 1_000, "μs"
    elif nanoseconds < 1_000_000_000:
        divisor, unit = 1_000_000, "ms"
    else:
        divisor, unit = 1_000_000_000, "s"

    # Three or more integer digits already carry the required precision.
    whole = nanoseconds // divisor
    if whole >= 100:
        return f"{whole}{unit}"
    return format_with_precision(nanoseconds / divisor, unit)

codeflash/optimization/function_optimizer.py

Lines changed: 5 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
N_TESTS_TO_GENERATE,
3737
TOTAL_LOOPING_TIME,
3838
)
39-
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
39+
from codeflash.code_utils.edit_generated_tests import (
40+
add_runtime_comments_to_generated_tests,
41+
remove_functions_from_generated_tests,
42+
)
4043
from codeflash.code_utils.formatter import format_code, sort_imports
4144
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4245
from codeflash.code_utils.line_profile_utils import add_decorator_imports
@@ -319,7 +322,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
319322
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
320323
)
321324
# Add runtime comments to generated tests before creating the PR
322-
generated_tests = self.add_runtime_comments_to_generated_tests(
325+
generated_tests = add_runtime_comments_to_generated_tests(
323326
generated_tests,
324327
original_code_baseline.benchmarking_test_results,
325328
best_optimization.winning_benchmarking_test_results,
@@ -1270,154 +1273,3 @@ def cleanup_generated_files(self) -> None:
12701273
cleanup_paths(paths_to_cleanup)
12711274
if hasattr(get_run_tmp_file, "tmpdir"):
12721275
get_run_tmp_file.tmpdir.cleanup()
1273-
1274-
def add_runtime_comments_to_generated_tests(
    self,
    generated_tests: GeneratedTestsList,
    original_test_results: TestResults,
    optimized_test_results: TestResults,
) -> GeneratedTestsList:
    """Annotate ``codeflash_output`` assignments in generated tests with before/after runtimes.

    Returns a new ``GeneratedTestsList``; a test whose source fails to process
    is kept unchanged.
    """

    def format_time(nanoseconds: int) -> str:
        """Render a nanosecond count as a short ns/μs/ms/s string."""

        def render(scaled: float, suffix: str) -> str:
            # Fewer decimals as the magnitude grows.
            if scaled >= 100:
                decimals = 0
            elif scaled >= 10:
                decimals = 1
            else:
                decimals = 2
            return f"{scaled:.{decimals}f}{suffix}"

        if nanoseconds < 1_000:
            return f"{nanoseconds}ns"

        if nanoseconds < 1_000_000:
            divisor, suffix = 1_000, "μs"
        elif nanoseconds < 1_000_000_000:
            divisor, suffix = 1_000_000, "ms"
        else:
            divisor, suffix = 1_000_000_000, "s"

        whole = nanoseconds // divisor
        # An integer part with three or more digits is printed without decimals.
        if len(str(abs(whole))) >= 3:
            return f"{whole}{suffix}"
        return render(nanoseconds / divisor, suffix)

    baseline_runtimes = original_test_results.usable_runtime_data_by_test_case()
    candidate_runtimes = optimized_test_results.usable_runtime_data_by_test_case()

    def runtimes_for(runtime_map, test_name):
        """Gather every runtime recorded under the given test function name."""
        return [
            runtime
            for invocation_id, runtimes in runtime_map.items()
            if invocation_id.test_function_name == test_name
            for runtime in runtimes
        ]

    class RuntimeCommentTransformer(cst.CSTTransformer):
        def __init__(self):
            self.in_test_function = False
            self.current_test_name = None

        def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
            name = node.name.value
            is_test = name.startswith("test_")
            self.in_test_function = is_test
            self.current_test_name = name if is_test else None

        def leave_FunctionDef(
            self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
        ) -> cst.FunctionDef:
            if original_node.name.value.startswith("test_"):
                self.in_test_function = False
                self.current_test_name = None
            return updated_node

        def leave_SimpleStatementLine(
            self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
        ) -> cst.SimpleStatementLine:
            if not (self.in_test_function and self.current_test_name):
                return updated_node

            # The line qualifies when any of its statements assigns to the
            # single name codeflash_output.
            assigns_output = any(
                isinstance(stmt, cst.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0].target, cst.Name)
                and stmt.targets[0].target.value == "codeflash_output"
                for stmt in updated_node.body
            )
            if not assigns_output:
                return updated_node

            before = runtimes_for(baseline_runtimes, self.current_test_name)
            after = runtimes_for(candidate_runtimes, self.current_test_name)
            if not (before and after):
                return updated_node

            comment_text = f"# {format_time(min(before))} -> {format_time(min(after))}"
            # Swap in a trailing whitespace node carrying the comment, keeping the newline.
            return updated_node.with_changes(
                trailing_whitespace=cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace(" "),
                    comment=cst.Comment(comment_text),
                    newline=updated_node.trailing_whitespace.newline,
                )
            )

    annotated = []
    for test in generated_tests.generated_tests:
        try:
            module = cst.parse_module(test.generated_original_test_source)
            rewritten = module.visit(RuntimeCommentTransformer())
            annotated.append(
                GeneratedTests(
                    generated_original_test_source=rewritten.code,
                    instrumented_behavior_test_source=test.instrumented_behavior_test_source,
                    instrumented_perf_test_source=test.instrumented_perf_test_source,
                    behavior_file_path=test.behavior_file_path,
                    perf_file_path=test.perf_file_path,
                )
            )
        except Exception as e:
            # On any failure, fall back to the unmodified test.
            logger.debug(f"Failed to add runtime comments to test: {e}")
            annotated.append(test)

    return GeneratedTestsList(generated_tests=annotated)

0 commit comments

Comments
 (0)