Skip to content

Commit ffff5f1

Browse files
author
Codeflash Bot
committed
all tests fixed
1 parent b91b61f commit ffff5f1

File tree

1 file changed

+95
-29
lines changed

1 file changed

+95
-29
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import platform
5+
from dataclasses import dataclass
56
from pathlib import Path
67
from typing import TYPE_CHECKING
78

@@ -20,6 +21,16 @@
2021
from codeflash.models.models import CodePosition
2122

2223

24+
@dataclass(frozen=True)
25+
class FunctionCallNodeArguments:
26+
args: list[ast.expr]
27+
keywords: list[ast.keyword]
28+
29+
30+
def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
31+
return FunctionCallNodeArguments(call_node.args, call_node.keywords)
32+
33+
2334
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
2435
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
2536
for pos in call_positions:
@@ -73,10 +84,12 @@ def __init__(
7384
def find_and_update_line_node(
7485
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
7586
) -> Iterable[ast.stmt] | None:
87+
return_statement = [test_node]
7688
call_node = None
7789
for node in ast.walk(test_node):
7890
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
7991
call_node = node
92+
all_args = get_call_arguments(call_node)
8093
if isinstance(node.func, ast.Name):
8194
function_name = node.func.id
8295

@@ -90,35 +103,33 @@ def find_and_update_line_node(
90103
func=ast.Attribute(
91104
value=ast.Call(
92105
func=ast.Attribute(
93-
value=ast.Name(id="inspect", ctx=ast.Load()),
94-
attr="signature",
95-
ctx=ast.Load()
106+
value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
96107
),
97108
args=[ast.Name(id=function_name, ctx=ast.Load())],
98-
keywords=[]
109+
keywords=[],
99110
),
100111
attr="bind",
101-
ctx=ast.Load()
112+
ctx=ast.Load(),
102113
),
103-
args=[ast.Starred(value=ast.Attribute(value=call_node, attr="args", ctx=ast.Load()), ctx=ast.Load())],
104-
keywords=[ast.keyword(arg=None, value=ast.Attribute(value=call_node, attr="keywords", ctx=ast.Load()))]
114+
args=all_args.args,
115+
keywords=all_args.keywords,
105116
),
106-
lineno=test_node.lineno if hasattr(test_node, 'lineno') else 1,
107-
col_offset=test_node.col_offset if hasattr(test_node, 'col_offset') else 0
117+
lineno=test_node.lineno,
118+
col_offset=test_node.col_offset,
108119
)
109120

110121
apply_defaults = ast.Expr(
111122
value=ast.Call(
112123
func=ast.Attribute(
113124
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
114125
attr="apply_defaults",
115-
ctx=ast.Load()
126+
ctx=ast.Load(),
116127
),
117128
args=[],
118-
keywords=[]
129+
keywords=[],
119130
),
120131
lineno=test_node.lineno + 1,
121-
col_offset=test_node.col_offset
132+
col_offset=test_node.col_offset,
122133
)
123134

124135
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
@@ -135,12 +146,40 @@ def find_and_update_line_node(
135146
if self.mode == TestingMode.BEHAVIOR
136147
else []
137148
),
138-
*call_node.args,
149+
*(
150+
call_node.args
151+
if self.mode == TestingMode.PERFORMANCE
152+
else [
153+
ast.Starred(
154+
value=ast.Attribute(
155+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
156+
attr="args",
157+
ctx=ast.Load(),
158+
),
159+
ctx=ast.Load(),
160+
)
161+
]
162+
),
139163
]
140-
node.keywords = call_node.keywords
164+
node.keywords = (
165+
[
166+
ast.keyword(
167+
value=ast.Attribute(
168+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
169+
attr="kwargs",
170+
ctx=ast.Load(),
171+
)
172+
)
173+
]
174+
if self.mode == TestingMode.BEHAVIOR
175+
else call_node.keywords
176+
)
141177

142178
# Return the signature binding statements along with the test_node
143-
return [bind_call, apply_defaults, test_node]
179+
return_statement = (
180+
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
181+
)
182+
break
144183
if isinstance(node.func, ast.Attribute):
145184
function_to_test = node.func.attr
146185
if function_to_test == self.function_object.function_name:
@@ -158,38 +197,38 @@ def find_and_update_line_node(
158197
func=ast.Attribute(
159198
value=ast.Name(id="inspect", ctx=ast.Load()),
160199
attr="signature",
161-
ctx=ast.Load()
200+
ctx=ast.Load(),
162201
),
163-
args=function_name,
164-
keywords=[]
202+
args=[ast.parse(function_name, mode="eval").body],
203+
keywords=[],
165204
),
166205
attr="bind",
167-
ctx=ast.Load()
206+
ctx=ast.Load(),
168207
),
169-
args=call_node.args,
170-
keywords=call_node.keywords
208+
args=all_args.args,
209+
keywords=all_args.keywords,
171210
),
172211
lineno=test_node.lineno,
173-
col_offset=test_node.col_offset
212+
col_offset=test_node.col_offset,
174213
)
175214

176215
apply_defaults = ast.Expr(
177216
value=ast.Call(
178217
func=ast.Attribute(
179218
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
180219
attr="apply_defaults",
181-
ctx=ast.Load()
220+
ctx=ast.Load(),
182221
),
183222
args=[],
184-
keywords=[]
223+
keywords=[],
185224
),
186225
lineno=test_node.lineno + 1,
187-
col_offset=test_node.col_offset
226+
col_offset=test_node.col_offset,
188227
)
189228

190229
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
191230
node.args = [
192-
ast.Name(id=function_name, ctx=ast.Load()),
231+
ast.parse(function_name, mode="eval").body,
193232
ast.Constant(value=self.module_path),
194233
ast.Constant(value=test_class_name or None),
195234
ast.Constant(value=node_name),
@@ -204,12 +243,39 @@ def find_and_update_line_node(
204243
if self.mode == TestingMode.BEHAVIOR
205244
else []
206245
),
207-
*call_node.args,
246+
*(
247+
call_node.args
248+
if self.mode == TestingMode.PERFORMANCE
249+
else [
250+
ast.Starred(
251+
value=ast.Attribute(
252+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
253+
attr="args",
254+
ctx=ast.Load(),
255+
),
256+
ctx=ast.Load(),
257+
)
258+
]
259+
),
208260
]
209-
node.keywords = call_node.keywords
261+
node.keywords = (
262+
[
263+
ast.keyword(
264+
value=ast.Attribute(
265+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
266+
attr="kwargs",
267+
ctx=ast.Load(),
268+
)
269+
)
270+
]
271+
if self.mode == TestingMode.BEHAVIOR
272+
else call_node.keywords
273+
)
210274

211275
# Return the signature binding statements along with the test_node
212-
return_statement = [bind_call, apply_defaults, test_node]
276+
return_statement = (
277+
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
278+
)
213279
break
214280

215281
if call_node is None:

0 commit comments

Comments
 (0)