Skip to content

Commit e25a185

Browse files
committed
first commit
1 parent 19dcbfb commit e25a185

File tree

3 files changed

+163
-9
lines changed

3 files changed

+163
-9
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 162 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
N_TESTS_TO_GENERATE,
3737
TOTAL_LOOPING_TIME,
3838
)
39+
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
3940
from codeflash.code_utils.formatter import format_code, sort_imports
4041
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4142
from codeflash.code_utils.line_profile_utils import add_decorator_imports
42-
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
4343
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
4444
from codeflash.code_utils.time_utils import humanize_runtime
4545
from codeflash.context import code_context_extractor
@@ -265,10 +265,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
265265
},
266266
)
267267

268-
generated_tests = remove_functions_from_generated_tests(
269-
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
270-
)
271-
272268
if best_optimization:
273269
logger.info("Best candidate:")
274270
code_print(best_optimization.candidate.source_code)
@@ -295,8 +291,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
295291
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
296292
)
297293

298-
self.log_successful_optimization(explanation, generated_tests, exp_type)
299-
300294
self.replace_function_and_helpers_with_optimized_code(
301295
code_context=code_context, optimized_code=best_optimization.candidate.source_code
302296
)
@@ -321,6 +315,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
321315
if original_code_baseline.coverage_results
322316
else "Coverage data not available"
323317
)
318+
generated_tests = remove_functions_from_generated_tests(
319+
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
320+
)
321+
# Add runtime comments to generated tests before creating the PR
322+
generated_tests = self.add_runtime_comments_to_generated_tests(
323+
generated_tests,
324+
original_code_baseline.benchmarking_test_results,
325+
best_optimization.winning_benchmarking_test_results,
326+
)
324327
generated_tests_str = "\n\n".join(
325328
[test.generated_original_test_source for test in generated_tests.generated_tests]
326329
)
@@ -345,6 +348,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
345348
original_helper_code,
346349
self.function_to_optimize.file_path,
347350
)
351+
self.log_successful_optimization(explanation, generated_tests, exp_type)
348352

349353
if not best_optimization:
350354
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
@@ -1266,3 +1270,154 @@ def cleanup_generated_files(self) -> None:
12661270
cleanup_paths(paths_to_cleanup)
12671271
if hasattr(get_run_tmp_file, "tmpdir"):
12681272
get_run_tmp_file.tmpdir.cleanup()
1273+
1274+
def add_runtime_comments_to_generated_tests(
    self,
    generated_tests: GeneratedTestsList,
    original_test_results: TestResults,
    optimized_test_results: TestResults,
) -> GeneratedTestsList:
    """Annotate ``codeflash_output = ...`` lines in generated tests with runtimes.

    Every simple statement line inside a ``test_*`` function that assigns to the
    single name ``codeflash_output`` receives a trailing comment of the form
    ``# <original> -> <optimized>``, where each side is the fastest runtime
    recorded for that test function in the corresponding results.

    Args:
        generated_tests: Generated tests whose ``generated_original_test_source``
            is rewritten.
        original_test_results: Benchmark results for the original code.
        optimized_test_results: Benchmark results for the optimized code.

    Returns:
        A new ``GeneratedTestsList``. A test whose source cannot be processed is
        carried over unchanged (best effort; the failure is logged at debug level).
    """

    def format_time(nanoseconds: int) -> str:
        """Render a nanosecond count as ns/μs/ms/s using roughly three significant digits."""
        if nanoseconds < 1_000:
            return f"{nanoseconds}ns"
        if nanoseconds < 1_000_000:
            divisor, unit = 1_000, "μs"
        elif nanoseconds < 1_000_000_000:
            divisor, unit = 1_000_000, "ms"
        else:
            divisor, unit = 1_000_000_000, "s"

        whole_units = nanoseconds // divisor
        if whole_units >= 100:
            # Three or more integer digits already: show the truncated integer.
            return f"{whole_units}{unit}"
        value = nanoseconds / divisor
        if value >= 10:
            return f"{value:.1f}{unit}"
        return f"{value:.2f}{unit}"

    def fastest_runtime_by_test_name(runtime_by_invocation):
        """Collapse per-invocation runtimes to the minimum runtime per test function name."""
        fastest = {}
        for invocation_id, runtimes in runtime_by_invocation.items():
            if not runtimes:
                continue
            name = invocation_id.test_function_name
            best = min(runtimes)
            if name not in fastest or best < fastest[name]:
                fastest[name] = best
        return fastest

    # Computed once so the transformer does a dictionary lookup per statement
    # instead of rescanning every recorded invocation.
    # NOTE(review): runtimes are presumably in nanoseconds, since they are passed
    # straight to format_time -- confirm against TestResults.
    fastest_original = fastest_runtime_by_test_name(original_test_results.usable_runtime_data_by_test_case())
    fastest_optimized = fastest_runtime_by_test_name(optimized_test_results.usable_runtime_data_by_test_case())

    class RuntimeCommentTransformer(cst.CSTTransformer):
        """Append runtime comments to ``codeflash_output`` assignments inside test functions."""

        def __init__(self) -> None:
            super().__init__()
            # One entry per enclosing function definition: the function's name when
            # it is a test (``test_*``), otherwise None. A stack is needed so that
            # leaving a nested helper restores the enclosing test instead of
            # dropping the rest of that test's statements.
            self.enclosing_tests = []

        def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
            name = node.name.value
            self.enclosing_tests.append(name if name.startswith("test_") else None)

        def leave_FunctionDef(
            self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
        ) -> cst.FunctionDef:
            if self.enclosing_tests:
                self.enclosing_tests.pop()
            return updated_node

        def leave_SimpleStatementLine(
            self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
        ) -> cst.SimpleStatementLine:
            current_test = self.enclosing_tests[-1] if self.enclosing_tests else None
            if not current_test:
                # Module level, or directly inside a non-test function.
                return updated_node

            # A line may hold several ``;``-separated statements; one matching
            # assignment is enough to annotate the line.
            assigns_codeflash_output = any(
                isinstance(stmt, cst.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0].target, cst.Name)
                and stmt.targets[0].target.value == "codeflash_output"
                for stmt in updated_node.body
            )
            if not assigns_codeflash_output:
                return updated_node

            original_time = fastest_original.get(current_test)
            optimized_time = fastest_optimized.get(current_test)
            if original_time is None or optimized_time is None:
                # Without a measurement on both sides there is nothing to compare.
                return updated_node

            comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
            new_trailing_whitespace = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(" "),
                comment=cst.Comment(comment_text),
                newline=updated_node.trailing_whitespace.newline,
            )
            return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

    modified_tests = []
    for test in generated_tests.generated_tests:
        try:
            tree = cst.parse_module(test.generated_original_test_source)
            modified_source = tree.visit(RuntimeCommentTransformer()).code
            # Only the original test source changes; instrumented sources and
            # file paths are carried over as-is.
            modified_tests.append(
                GeneratedTests(
                    generated_original_test_source=modified_source,
                    instrumented_behavior_test_source=test.instrumented_behavior_test_source,
                    instrumented_perf_test_source=test.instrumented_perf_test_source,
                    behavior_file_path=test.behavior_file_path,
                    perf_file_path=test.perf_file_path,
                )
            )
        except Exception as e:
            # Best effort: annotation must never block the optimization result,
            # so keep the unmodified test when parsing or transforming fails.
            logger.debug(f"Failed to add runtime comments to test: {e}")
            modified_tests.append(test)

    return GeneratedTestsList(generated_tests=modified_tests)

tests/test_remove_functions_from_generated_tests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from pathlib import Path
22

33
import pytest
4-
5-
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
4+
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
65
from codeflash.models.models import GeneratedTests, GeneratedTestsList
76

87

0 commit comments

Comments
 (0)