Skip to content

Commit 30410ff

Browse files
committed
works
1 parent 0d566bf commit 30410ff

File tree

2 files changed

+56
-22
lines changed

2 files changed

+56
-22
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import os
12
import re
3+
from pathlib import Path
24

35
import libcst as cst
46

57
from codeflash.cli_cmds.console import logger
68
from codeflash.code_utils.time_utils import format_time
79
from codeflash.models.models import GeneratedTests, GeneratedTestsList
10+
from codeflash.verification.verification_utils import TestConfig
811

912

1013
def remove_functions_from_generated_tests(
@@ -33,38 +36,43 @@ def remove_functions_from_generated_tests(
3336

3437

3538
def add_runtime_comments_to_generated_tests(
36-
generated_tests: GeneratedTestsList, original_runtimes: dict, optimized_runtimes: dict
39+
test_cfg: TestConfig, generated_tests: GeneratedTestsList, original_runtimes: dict, optimized_runtimes: dict
3740
) -> GeneratedTestsList:
3841
"""Add runtime performance comments to function calls in generated tests."""
42+
tests_root = test_cfg.tests_root
43+
module_root = test_cfg.project_root_path
44+
rel_tests_root = tests_root.relative_to(module_root)
3945

4046
# TODO: reduce for loops to one
4147
class RuntimeCommentTransformer(cst.CSTTransformer):
42-
def __init__(self) -> None:
43-
self.in_test_function = False
44-
self.current_test_name: str | None = None
48+
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
49+
self.test = test
50+
self.context_stack = []
51+
self.tests_root = tests_root
52+
self.rel_tests_root = rel_tests_root
53+
54+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
55+
# Track when we enter a class
56+
self.context_stack.append(node.name.value)
57+
58+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
59+
# Pop the context when we leave a class
60+
self.context_stack.pop()
61+
return updated_node
4562

4663
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
47-
if node.name.value.startswith("test_"):
48-
self.in_test_function = True
49-
self.current_test_name = node.name.value
50-
else:
51-
self.in_test_function = False
52-
self.current_test_name = None
53-
54-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
55-
if original_node.name.value.startswith("test_"):
56-
self.in_test_function = False
57-
self.current_test_name = None
64+
self.context_stack.append(node.name.value)
65+
66+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
67+
# Pop the context when we leave a function
68+
self.context_stack.pop()
5869
return updated_node
5970

6071
def leave_SimpleStatementLine(
6172
self,
6273
original_node: cst.SimpleStatementLine, # noqa: ARG002
6374
updated_node: cst.SimpleStatementLine,
6475
) -> cst.SimpleStatementLine:
65-
if not self.in_test_function or not self.current_test_name:
66-
return updated_node
67-
6876
# Look for assignment statements that assign to codeflash_output
6977
# Handle both single statements and multiple statements on one line
7078
codeflash_assignment_found = False
@@ -83,11 +91,37 @@ def leave_SimpleStatementLine(
8391
matching_optimized_times = []
8492
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
8593
for invocation_id, runtimes in original_runtimes.items():
86-
if invocation_id.test_function_name == self.current_test_name:
94+
qualified_name = (
95+
invocation_id.test_class_name + "." + invocation_id.test_function_name
96+
if invocation_id.test_class_name
97+
else invocation_id.test_function_name
98+
)
99+
rel_path = (
100+
Path(invocation_id.test_module_path.replace(".", os.sep))
101+
.with_suffix(".py")
102+
.relative_to(self.rel_tests_root)
103+
)
104+
if qualified_name == ".".join(self.context_stack) and rel_path in [
105+
self.test.behavior_file_path.relative_to(self.tests_root),
106+
self.test.perf_file_path.relative_to(self.tests_root),
107+
]:
87108
matching_original_times.extend(runtimes)
88109

89110
for invocation_id, runtimes in optimized_runtimes.items():
90-
if invocation_id.test_function_name == self.current_test_name:
111+
qualified_name = (
112+
invocation_id.test_class_name + "." + invocation_id.test_function_name
113+
if invocation_id.test_class_name
114+
else invocation_id.test_function_name
115+
)
116+
rel_path = (
117+
Path(invocation_id.test_module_path.replace(".", os.sep))
118+
.with_suffix(".py")
119+
.relative_to(self.rel_tests_root)
120+
)
121+
if qualified_name == ".".join(self.context_stack) and rel_path in [
122+
self.test.behavior_file_path.relative_to(self.tests_root),
123+
self.test.perf_file_path.relative_to(self.tests_root),
124+
]:
91125
matching_optimized_times.extend(runtimes)
92126

93127
if matching_original_times and matching_optimized_times:
@@ -116,7 +150,7 @@ def leave_SimpleStatementLine(
116150
tree = cst.parse_module(test.generated_original_test_source)
117151

118152
# Transform the tree to add runtime comments
119-
transformer = RuntimeCommentTransformer()
153+
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
120154
modified_tree = tree.visit(transformer)
121155

122156
# Convert back to source code

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
362362
)
363363
# Add runtime comments to generated tests before creating the PR
364364
generated_tests = add_runtime_comments_to_generated_tests(
365-
generated_tests, original_runtime_by_test, optimized_runtime_by_test
365+
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
366366
)
367367
generated_tests_str = "\n\n".join(
368368
[test.generated_original_test_source for test in generated_tests.generated_tests]

0 commit comments

Comments
 (0)