Skip to content

Commit 4ab32b9

Browse files
committed
un-nesting fns
1 parent 9d31359 commit 4ab32b9

File tree

1 file changed

+160
-148
lines changed

1 file changed

+160
-148
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 160 additions & 148 deletions
Original file line number | Diff line number | Diff line change
@@ -50,9 +50,9 @@ class CfoVisitor(ast.NodeVisitor):
5050
and reports their location relative to the function they're in.
5151
"""
5252

53-
def __init__(self, qualified_name: str, source_code: str) -> None:
53+
def __init__(self, function_name: str, source_code: str) -> None:
5454
self.source_lines = source_code.splitlines()
55-
self.name = qualified_name.split(".")[-1]
55+
self.name = function_name
5656
self.results: list[int] = [] # map actual line number to line number in ast
5757

5858
def visit_Call(self, node): # type: ignore[no-untyped-def] # noqa: ANN201, ANN001
@@ -71,13 +71,166 @@ def _get_called_func_name(self, node): # type: ignore[no-untyped-def] # noqa: A
7171
return None
7272

7373

74-
def find_codeflash_output_assignments(function_name: str, source_code: str) -> list[int]:
    """Parse *source_code* and return the results collected by a ``CfoVisitor``.

    Args:
        function_name: Name handed to ``CfoVisitor``; the visitor stores it
            as the function name it works with.
        source_code: Python source text; it is parsed with ``ast.parse`` and
            also passed verbatim to the visitor.

    Returns:
        The visitor's ``results`` list after it has walked the parsed tree.

    Raises:
        SyntaxError: If ``ast.parse`` cannot parse *source_code*.
    """
    parsed_tree = ast.parse(source_code)
    cfo_visitor = CfoVisitor(function_name, source_code)
    cfo_visitor.visit(parsed_tree)
    return cfo_visitor.results
7979

8080

81+
class Finder(cst.CSTVisitor):
    """Visitor that records whether any visited call targets a given name.

    A call matches when it is a plain ``name(...)`` call or an attribute call
    ``obj.name(...)`` whose final attribute equals ``name``. ``found`` starts
    as ``False`` and is set to ``True`` on the first match; it is never reset.
    """

    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name
        self.found = False

    def visit_Call(self, call_node) -> None:  # type: ignore[no-untyped-def] # noqa : ANN001
        callee = call_node.func
        if isinstance(callee, cst.Name):
            called_name = callee.value
        elif isinstance(callee, cst.Attribute):
            called_name = callee.attr.value
        else:
            # Any other callee expression (subscript, nested call, lambda, ...) is ignored.
            return
        if called_name == self.name:
            self.found = True
95+
96+
97+
class RuntimeCommentTransformer(cst.CSTTransformer):
    """Append a runtime-comparison comment to statements that call the target function.

    While walking a generated test module the transformer keeps a stack of the
    enclosing class/function names. For every simple statement line that
    contains a call to ``self.name`` it collects the original and optimized
    runtimes recorded for that test and call position and, when both are
    available and non-zero, replaces the statement's trailing whitespace with
    a comment of the form ``# <original> -> <optimized> (<gain>% faster|slower)``.
    """

    def __init__(
        self,
        qualified_name: str,
        module: cst.Module,
        test: GeneratedTests,
        tests_root: Path,
        rel_tests_root: Path,
        original_runtimes: dict[InvocationId, list[int]],
        optimized_runtimes: dict[InvocationId, list[int]],
    ) -> None:
        """Store the lookup tables and reset the traversal state.

        Args:
            qualified_name: Dotted name of the function under test; only the
                last component is used to recognise calls.
            module: Parsed module being transformed; used to render function
                bodies back to source text.
            test: Generated test whose behavior/perf file paths are compared
                against the invocation ids.
            tests_root: Root that the test's file paths are made relative to.
            rel_tests_root: Root that paths derived from an invocation id's
                module path are made relative to.
            original_runtimes: Runtimes of the original code per invocation id.
            optimized_runtimes: Runtimes of the optimized code per invocation id.
        """
        super().__init__()
        self.test = test
        self.context_stack: list[str] = []
        self.tests_root = tests_root
        self.rel_tests_root = rel_tests_root
        self.module = module
        self.cfo_locs: list[int] = []
        self.cfo_idx_loc_to_look_at: int = -1
        self.name = qualified_name.split(".")[-1]
        self.original_runtimes = original_runtimes
        self.optimized_runtimes = optimized_runtimes

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Track when we enter a class
        self.context_stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
        # Pop the context when we leave a class
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # Convert the function body to an ast-normalized string and find occurrences of codeflash_output
        body_code = dedent(self.module.code_for_node(node.body))
        normalized_body_code = ast.unparse(ast.parse(body_code))
        # NOTE(review): these two fields are overwritten for every FunctionDef and
        # are not restored on leave, so a nested function definition discards the
        # enclosing function's locations - confirm generated tests never nest defs.
        self.cfo_locs = sorted(
            find_codeflash_output_assignments(self.name, normalized_body_code)
        )  # sorted in the order we will encounter them
        self.cfo_idx_loc_to_look_at = -1
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:  # noqa: ARG002
        # Pop the context when we leave a function
        self.context_stack.pop()
        return updated_node

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,  # noqa: ARG002
        updated_node: cst.SimpleStatementLine,
    ) -> cst.SimpleStatementLine:
        """Return the statement with a runtime comment attached, or unchanged if none applies."""
        # Only statement lines that call self.name are candidates for a comment.
        if not self._contains_myfunc_call(updated_node):  # type: ignore[no-untyped-call]
            return updated_node

        self.cfo_idx_loc_to_look_at += 1
        # The counter advances on every statement that calls the target function,
        # so it can run past the recorded locations (for example when a call is not
        # one of the recorded codeflash_output statements). Skip annotation rather
        # than raising IndexError in the lookup below.
        if self.cfo_idx_loc_to_look_at >= len(self.cfo_locs):
            return updated_node

        matching_original_times = self._matching_runtimes(self.original_runtimes)
        matching_optimized_times = self._matching_runtimes(self.optimized_runtimes)
        if not (matching_original_times and matching_optimized_times):
            return updated_node

        original_time = min(matching_original_times)
        optimized_time = min(matching_optimized_times)
        if original_time == 0 or optimized_time == 0:
            # A zero runtime on either side yields no comment.
            return updated_node

        perf_gain = format_perf(
            abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
        )
        status = "slower" if optimized_time > original_time else "faster"
        # Create the runtime comment
        comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
        return updated_node.with_changes(
            trailing_whitespace=cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(" "),
                comment=cst.Comment(comment_text),
                newline=updated_node.trailing_whitespace.newline,
            )
        )

    def _matching_runtimes(self, runtimes_by_invocation: dict[InvocationId, list[int]]) -> list[int]:
        """Collect the runtimes recorded for the statement currently being visited.

        An invocation matches when its test class/function name equals the
        current context stack, its module path maps to this test's behavior or
        perf file, and the leading integer of its iteration id equals the
        location currently selected from ``self.cfo_locs``. The caller must
        ensure ``self.cfo_idx_loc_to_look_at`` is a valid index.
        """
        # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
        current_context = ".".join(self.context_stack)
        expected_loc = self.cfo_locs[self.cfo_idx_loc_to_look_at]
        matched: list[int] = []
        for invocation_id, runtimes in runtimes_by_invocation.items():
            invocation_test_name = (
                invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
                if invocation_id.test_class_name
                else invocation_id.test_function_name
            )
            rel_path = (
                Path(invocation_id.test_module_path.replace(".", os.sep))
                .with_suffix(".py")
                .relative_to(self.rel_tests_root)
            )
            if (
                invocation_test_name == current_context
                and rel_path
                in [
                    self.test.behavior_file_path.relative_to(self.tests_root),
                    self.test.perf_file_path.relative_to(self.tests_root),
                ]
                and int(invocation_id.iteration_id.split("_")[0]) == expected_loc  # type:ignore[union-attr]
            ):
                matched.extend(runtimes)
        return matched

    def _contains_myfunc_call(self, node):  # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
        """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
        finder = Finder(self.name)
        node.visit(finder)
        return finder.found
232+
233+
81234
def add_runtime_comments_to_generated_tests(
82235
qualified_name: str,
83236
test_cfg: TestConfig,
@@ -90,149 +243,6 @@ def add_runtime_comments_to_generated_tests(
90243
module_root = test_cfg.project_root_path
91244
rel_tests_root = tests_root.relative_to(module_root)
92245

93-
# TODO: reduce for loops to one
94-
class RuntimeCommentTransformer(cst.CSTTransformer):
95-
def __init__(
96-
self, qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
97-
) -> None:
98-
super().__init__()
99-
self.test = test
100-
self.context_stack: list[str] = []
101-
self.tests_root = tests_root
102-
self.rel_tests_root = rel_tests_root
103-
self.module = module
104-
self.cfo_locs: list[int] = []
105-
self.cfo_idx_loc_to_look_at: int = -1
106-
self.name = qualified_name.split(".")[-1]
107-
108-
def visit_ClassDef(self, node: cst.ClassDef) -> None:
109-
# Track when we enter a class
110-
self.context_stack.append(node.name.value)
111-
112-
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
113-
# Pop the context when we leave a class
114-
self.context_stack.pop()
115-
return updated_node
116-
117-
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
118-
# convert function body to ast normalized string and find occurrences of codeflash_output
119-
body_code = dedent(self.module.code_for_node(node.body))
120-
normalized_body_code = ast.unparse(ast.parse(body_code))
121-
self.cfo_locs = sorted(
122-
find_codeflash_output_assignments(qualified_name, normalized_body_code)
123-
) # sorted in order we will encounter them
124-
self.cfo_idx_loc_to_look_at = -1
125-
self.context_stack.append(node.name.value)
126-
127-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
128-
# Pop the context when we leave a function
129-
self.context_stack.pop()
130-
return updated_node
131-
132-
def leave_SimpleStatementLine(
133-
self,
134-
original_node: cst.SimpleStatementLine, # noqa: ARG002
135-
updated_node: cst.SimpleStatementLine,
136-
) -> cst.SimpleStatementLine:
137-
# Check if this statement line contains a call to self.name
138-
if self._contains_myfunc_call(updated_node): # type: ignore[no-untyped-call]
139-
# Find matching test cases by looking for this test function name in the test results
140-
self.cfo_idx_loc_to_look_at += 1
141-
matching_original_times = []
142-
matching_optimized_times = []
143-
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
144-
for invocation_id, runtimes in original_runtimes.items():
145-
# get position here and match in if condition
146-
qualified_name = (
147-
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
148-
if invocation_id.test_class_name
149-
else invocation_id.test_function_name
150-
)
151-
rel_path = (
152-
Path(invocation_id.test_module_path.replace(".", os.sep))
153-
.with_suffix(".py")
154-
.relative_to(self.rel_tests_root)
155-
)
156-
if (
157-
qualified_name == ".".join(self.context_stack)
158-
and rel_path
159-
in [
160-
self.test.behavior_file_path.relative_to(self.tests_root),
161-
self.test.perf_file_path.relative_to(self.tests_root),
162-
]
163-
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
164-
):
165-
matching_original_times.extend(runtimes)
166-
167-
for invocation_id, runtimes in optimized_runtimes.items():
168-
# get position here and match in if condition
169-
qualified_name = (
170-
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
171-
if invocation_id.test_class_name
172-
else invocation_id.test_function_name
173-
)
174-
rel_path = (
175-
Path(invocation_id.test_module_path.replace(".", os.sep))
176-
.with_suffix(".py")
177-
.relative_to(self.rel_tests_root)
178-
)
179-
if (
180-
qualified_name == ".".join(self.context_stack)
181-
and rel_path
182-
in [
183-
self.test.behavior_file_path.relative_to(self.tests_root),
184-
self.test.perf_file_path.relative_to(self.tests_root),
185-
]
186-
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
187-
):
188-
matching_optimized_times.extend(runtimes)
189-
190-
if matching_original_times and matching_optimized_times:
191-
original_time = min(matching_original_times)
192-
optimized_time = min(matching_optimized_times)
193-
if original_time != 0 and optimized_time != 0:
194-
perf_gain = format_perf(
195-
abs(
196-
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
197-
* 100
198-
)
199-
)
200-
status = "slower" if optimized_time > original_time else "faster"
201-
# Create the runtime comment
202-
comment_text = (
203-
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
204-
)
205-
return updated_node.with_changes(
206-
trailing_whitespace=cst.TrailingWhitespace(
207-
whitespace=cst.SimpleWhitespace(" "),
208-
comment=cst.Comment(comment_text),
209-
newline=updated_node.trailing_whitespace.newline,
210-
)
211-
)
212-
return updated_node
213-
214-
def _contains_myfunc_call(self, node): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
215-
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
216-
217-
class Finder(cst.CSTVisitor):
218-
def __init__(self, name: str) -> None:
219-
super().__init__()
220-
self.found = False
221-
self.name = name
222-
223-
def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa : ANN001
224-
func_expr = call_node.func
225-
if isinstance(func_expr, cst.Name):
226-
if func_expr.value == self.name:
227-
self.found = True
228-
elif isinstance(func_expr, cst.Attribute): # noqa : SIM102
229-
if func_expr.attr.value == self.name:
230-
self.found = True
231-
232-
finder = Finder(self.name)
233-
node.visit(finder)
234-
return finder.found
235-
236246
# Process each generated test
237247
modified_tests = []
238248
for test in generated_tests.generated_tests:
@@ -241,7 +251,9 @@ def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa
241251
tree = cst.parse_module(test.generated_original_test_source)
242252
# Transform the tree to add runtime comments
243253
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
244-
transformer = RuntimeCommentTransformer(qualified_name, tree, test, tests_root, rel_tests_root)
254+
transformer = RuntimeCommentTransformer(
255+
qualified_name, tree, test, tests_root, rel_tests_root, original_runtimes, optimized_runtimes
256+
)
245257
modified_tree = tree.visit(transformer)
246258

247259
# Convert back to source code

0 commit comments

Comments
 (0)