Skip to content

Commit 40e82e2

Browse files
Merge pull request #870 from codeflash-ai/codeflash/optimize-pr867-2025-11-01T00.02.02
⚡️ Speed up method `InjectPerfOnly.find_and_update_line_node` by 22% in PR #867 (`inspect-signature-issue`)
2 parents c2817f9 + b3c3ca8 commit 40e82e2

File tree

1 file changed

+189
-156
lines changed

1 file changed

+189
-156
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 189 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,28 @@ def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
3232

3333

3434
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
35-
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
36-
for pos in call_positions:
37-
if (
38-
pos.line_no is not None
39-
and node.end_lineno is not None
40-
and node.lineno <= pos.line_no <= node.end_lineno
41-
):
42-
if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
43-
return True
44-
if (
45-
pos.line_no == node.end_lineno
46-
and node.end_col_offset is not None
47-
and node.end_col_offset >= pos.col_no
48-
):
49-
return True
50-
if node.lineno < pos.line_no < node.end_lineno:
51-
return True
35+
# Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty.
36+
# Small optimizations for tight loop:
37+
if isinstance(node, ast.Call):
38+
node_lineno = getattr(node, "lineno", None)
39+
node_col_offset = getattr(node, "col_offset", None)
40+
node_end_lineno = getattr(node, "end_lineno", None)
41+
node_end_col_offset = getattr(node, "end_col_offset", None)
42+
if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None:
43+
# Faster loop: reduce attribute lookups, use local variables for conditionals.
44+
for pos in call_positions:
45+
pos_line = pos.line_no
46+
if pos_line is not None and node_lineno <= pos_line <= node_end_lineno:
47+
if pos_line == node_lineno and node_col_offset <= pos.col_no:
48+
return True
49+
if (
50+
pos_line == node_end_lineno
51+
and node_end_col_offset is not None
52+
and node_end_col_offset >= pos.col_no
53+
):
54+
return True
55+
if node_lineno < pos_line < node_end_lineno:
56+
return True
5257
return False
5358

5459

@@ -84,28 +89,157 @@ def __init__(
8489
def find_and_update_line_node(
8590
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
8691
) -> Iterable[ast.stmt] | None:
92+
# Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call,
93+
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
94+
95+
# Helper for manual walk
96+
def iter_ast_calls(node):
97+
# Generator to yield each ast.Call in test_node, preserves node identity
98+
stack = [node]
99+
while stack:
100+
n = stack.pop()
101+
if isinstance(n, ast.Call):
102+
yield n
103+
# Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node),
104+
# do a specialized BFS with only the necessary attributes
105+
for field, value in ast.iter_fields(n):
106+
if isinstance(value, list):
107+
for item in reversed(value):
108+
if isinstance(item, ast.AST):
109+
stack.append(item)
110+
elif isinstance(value, ast.AST):
111+
stack.append(value)
112+
113+
# This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead
87114
return_statement = [test_node]
88115
call_node = None
89-
for node in ast.walk(test_node):
90-
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
91-
call_node = node
92-
all_args = get_call_arguments(call_node)
93-
if isinstance(node.func, ast.Name):
94-
function_name = node.func.id
95-
96-
if self.function_object.is_async:
116+
117+
# Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals
118+
fn_obj = self.function_object
119+
module_path = self.module_path
120+
mode = self.mode
121+
qualified_name = fn_obj.qualified_name
122+
123+
# Use locals for all 'current' values, only look up class/function/constant AST object once.
124+
codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load())
125+
codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load())
126+
codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())
127+
128+
for node in iter_ast_calls(test_node):
129+
if not node_in_call_position(node, self.call_positions):
130+
continue
131+
132+
call_node = node
133+
all_args = get_call_arguments(call_node)
134+
# Two possible call types: Name and Attribute
135+
node_func = node.func
136+
137+
if isinstance(node_func, ast.Name):
138+
function_name = node_func.id
139+
140+
if fn_obj.is_async:
141+
return [test_node]
142+
143+
# Build once, reuse objects.
144+
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
145+
bind_call = ast.Assign(
146+
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
147+
value=ast.Call(
148+
func=ast.Attribute(
149+
value=ast.Call(
150+
func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()),
151+
args=[ast.Name(id=function_name, ctx=ast.Load())],
152+
keywords=[],
153+
),
154+
attr="bind",
155+
ctx=ast.Load(),
156+
),
157+
args=all_args.args,
158+
keywords=all_args.keywords,
159+
),
160+
lineno=test_node.lineno,
161+
col_offset=test_node.col_offset,
162+
)
163+
164+
apply_defaults = ast.Expr(
165+
value=ast.Call(
166+
func=ast.Attribute(
167+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
168+
attr="apply_defaults",
169+
ctx=ast.Load(),
170+
),
171+
args=[],
172+
keywords=[],
173+
),
174+
lineno=test_node.lineno + 1,
175+
col_offset=test_node.col_offset,
176+
)
177+
178+
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
179+
base_args = [
180+
ast.Name(id=function_name, ctx=ast.Load()),
181+
ast.Constant(value=module_path),
182+
ast.Constant(value=test_class_name or None),
183+
ast.Constant(value=node_name),
184+
ast.Constant(value=qualified_name),
185+
ast.Constant(value=index),
186+
codeflash_loop_index,
187+
]
188+
# Extend with BEHAVIOR extras if needed
189+
if mode == TestingMode.BEHAVIOR:
190+
base_args += [codeflash_cur, codeflash_con]
191+
# Extend with call args (performance) or starred bound args (behavior)
192+
if mode == TestingMode.PERFORMANCE:
193+
base_args += call_node.args
194+
else:
195+
base_args.append(
196+
ast.Starred(
197+
value=ast.Attribute(
198+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
199+
attr="args",
200+
ctx=ast.Load(),
201+
),
202+
ctx=ast.Load(),
203+
)
204+
)
205+
node.args = base_args
206+
# Prepare keywords
207+
if mode == TestingMode.BEHAVIOR:
208+
node.keywords = [
209+
ast.keyword(
210+
value=ast.Attribute(
211+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
212+
attr="kwargs",
213+
ctx=ast.Load(),
214+
)
215+
)
216+
]
217+
else:
218+
node.keywords = call_node.keywords
219+
220+
return_statement = (
221+
[bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node]
222+
)
223+
break
224+
if isinstance(node_func, ast.Attribute):
225+
function_to_test = node_func.attr
226+
if function_to_test == fn_obj.function_name:
227+
if fn_obj.is_async:
97228
return [test_node]
98229

99230
# Create the signature binding statements
231+
232+
# Unparse only once
233+
function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body
234+
235+
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
100236
bind_call = ast.Assign(
101237
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
102238
value=ast.Call(
103239
func=ast.Attribute(
104240
value=ast.Call(
105-
func=ast.Attribute(
106-
value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
107-
),
108-
args=[ast.Name(id=function_name, ctx=ast.Load())],
241+
func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()),
242+
args=[function_name_expr],
109243
keywords=[],
110244
),
111245
attr="bind",
@@ -133,36 +267,33 @@ def find_and_update_line_node(
133267
)
134268

135269
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
136-
node.args = [
137-
ast.Name(id=function_name, ctx=ast.Load()),
138-
ast.Constant(value=self.module_path),
270+
base_args = [
271+
function_name_expr,
272+
ast.Constant(value=module_path),
139273
ast.Constant(value=test_class_name or None),
140274
ast.Constant(value=node_name),
141-
ast.Constant(value=self.function_object.qualified_name),
275+
ast.Constant(value=qualified_name),
142276
ast.Constant(value=index),
143-
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
144-
*(
145-
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
146-
if self.mode == TestingMode.BEHAVIOR
147-
else []
148-
),
149-
*(
150-
call_node.args
151-
if self.mode == TestingMode.PERFORMANCE
152-
else [
153-
ast.Starred(
154-
value=ast.Attribute(
155-
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
156-
attr="args",
157-
ctx=ast.Load(),
158-
),
159-
ctx=ast.Load(),
160-
)
161-
]
162-
),
277+
codeflash_loop_index,
163278
]
164-
node.keywords = (
165-
[
279+
if mode == TestingMode.BEHAVIOR:
280+
base_args += [codeflash_cur, codeflash_con]
281+
if mode == TestingMode.PERFORMANCE:
282+
base_args += call_node.args
283+
else:
284+
base_args.append(
285+
ast.Starred(
286+
value=ast.Attribute(
287+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
288+
attr="args",
289+
ctx=ast.Load(),
290+
),
291+
ctx=ast.Load(),
292+
)
293+
)
294+
node.args = base_args
295+
if mode == TestingMode.BEHAVIOR:
296+
node.keywords = [
166297
ast.keyword(
167298
value=ast.Attribute(
168299
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
@@ -171,112 +302,14 @@ def find_and_update_line_node(
171302
)
172303
)
173304
]
174-
if self.mode == TestingMode.BEHAVIOR
175-
else call_node.keywords
176-
)
305+
else:
306+
node.keywords = call_node.keywords
177307

178308
# Return the signature binding statements along with the test_node
179309
return_statement = (
180-
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
310+
[bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node]
181311
)
182312
break
183-
if isinstance(node.func, ast.Attribute):
184-
function_to_test = node.func.attr
185-
if function_to_test == self.function_object.function_name:
186-
if self.function_object.is_async:
187-
return [test_node]
188-
189-
function_name = ast.unparse(node.func)
190-
191-
# Create the signature binding statements
192-
bind_call = ast.Assign(
193-
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
194-
value=ast.Call(
195-
func=ast.Attribute(
196-
value=ast.Call(
197-
func=ast.Attribute(
198-
value=ast.Name(id="inspect", ctx=ast.Load()),
199-
attr="signature",
200-
ctx=ast.Load(),
201-
),
202-
args=[ast.parse(function_name, mode="eval").body],
203-
keywords=[],
204-
),
205-
attr="bind",
206-
ctx=ast.Load(),
207-
),
208-
args=all_args.args,
209-
keywords=all_args.keywords,
210-
),
211-
lineno=test_node.lineno,
212-
col_offset=test_node.col_offset,
213-
)
214-
215-
apply_defaults = ast.Expr(
216-
value=ast.Call(
217-
func=ast.Attribute(
218-
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
219-
attr="apply_defaults",
220-
ctx=ast.Load(),
221-
),
222-
args=[],
223-
keywords=[],
224-
),
225-
lineno=test_node.lineno + 1,
226-
col_offset=test_node.col_offset,
227-
)
228-
229-
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
230-
node.args = [
231-
ast.parse(function_name, mode="eval").body,
232-
ast.Constant(value=self.module_path),
233-
ast.Constant(value=test_class_name or None),
234-
ast.Constant(value=node_name),
235-
ast.Constant(value=self.function_object.qualified_name),
236-
ast.Constant(value=index),
237-
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
238-
*(
239-
[
240-
ast.Name(id="codeflash_cur", ctx=ast.Load()),
241-
ast.Name(id="codeflash_con", ctx=ast.Load()),
242-
]
243-
if self.mode == TestingMode.BEHAVIOR
244-
else []
245-
),
246-
*(
247-
call_node.args
248-
if self.mode == TestingMode.PERFORMANCE
249-
else [
250-
ast.Starred(
251-
value=ast.Attribute(
252-
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
253-
attr="args",
254-
ctx=ast.Load(),
255-
),
256-
ctx=ast.Load(),
257-
)
258-
]
259-
),
260-
]
261-
node.keywords = (
262-
[
263-
ast.keyword(
264-
value=ast.Attribute(
265-
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
266-
attr="kwargs",
267-
ctx=ast.Load(),
268-
)
269-
)
270-
]
271-
if self.mode == TestingMode.BEHAVIOR
272-
else call_node.keywords
273-
)
274-
275-
# Return the signature binding statements along with the test_node
276-
return_statement = (
277-
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
278-
)
279-
break
280313

281314
if call_node is None:
282315
return None

0 commit comments

Comments
 (0)