Skip to content

Commit b91b61f

Browse files
committed
wip
1 parent 376ca0d commit b91b61f

File tree

1 file changed

+85
-2
lines changed

1 file changed

+85
-2
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,44 @@ def find_and_update_line_node(
8383
if self.function_object.is_async:
8484
return [test_node]
8585

86+
# Create the signature binding statements
87+
bind_call = ast.Assign(
88+
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
89+
value=ast.Call(
90+
func=ast.Attribute(
91+
value=ast.Call(
92+
func=ast.Attribute(
93+
value=ast.Name(id="inspect", ctx=ast.Load()),
94+
attr="signature",
95+
ctx=ast.Load()
96+
),
97+
args=[ast.Name(id=function_name, ctx=ast.Load())],
98+
keywords=[]
99+
),
100+
attr="bind",
101+
ctx=ast.Load()
102+
),
103+
args=[ast.Starred(value=ast.Attribute(value=call_node, attr="args", ctx=ast.Load()), ctx=ast.Load())],
104+
keywords=[ast.keyword(arg=None, value=ast.Attribute(value=call_node, attr="keywords", ctx=ast.Load()))]
105+
),
106+
lineno=test_node.lineno if hasattr(test_node, 'lineno') else 1,
107+
col_offset=test_node.col_offset if hasattr(test_node, 'col_offset') else 0
108+
)
109+
110+
apply_defaults = ast.Expr(
111+
value=ast.Call(
112+
func=ast.Attribute(
113+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
114+
attr="apply_defaults",
115+
ctx=ast.Load()
116+
),
117+
args=[],
118+
keywords=[]
119+
),
120+
lineno=test_node.lineno + 1,
121+
col_offset=test_node.col_offset
122+
)
123+
86124
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
87125
node.args = [
88126
ast.Name(id=function_name, ctx=ast.Load()),
@@ -100,14 +138,55 @@ def find_and_update_line_node(
100138
*call_node.args,
101139
]
102140
node.keywords = call_node.keywords
103-
break
141+
142+
# Return the signature binding statements along with the test_node
143+
return [bind_call, apply_defaults, test_node]
104144
if isinstance(node.func, ast.Attribute):
105145
function_to_test = node.func.attr
106146
if function_to_test == self.function_object.function_name:
107147
if self.function_object.is_async:
108148
return [test_node]
109149

110150
function_name = ast.unparse(node.func)
151+
152+
# Create the signature binding statements
153+
bind_call = ast.Assign(
154+
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
155+
value=ast.Call(
156+
func=ast.Attribute(
157+
value=ast.Call(
158+
func=ast.Attribute(
159+
value=ast.Name(id="inspect", ctx=ast.Load()),
160+
attr="signature",
161+
ctx=ast.Load()
162+
),
163+
args=function_name,
164+
keywords=[]
165+
),
166+
attr="bind",
167+
ctx=ast.Load()
168+
),
169+
args=call_node.args,
170+
keywords=call_node.keywords
171+
),
172+
lineno=test_node.lineno,
173+
col_offset=test_node.col_offset
174+
)
175+
176+
apply_defaults = ast.Expr(
177+
value=ast.Call(
178+
func=ast.Attribute(
179+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
180+
attr="apply_defaults",
181+
ctx=ast.Load()
182+
),
183+
args=[],
184+
keywords=[]
185+
),
186+
lineno=test_node.lineno + 1,
187+
col_offset=test_node.col_offset
188+
)
189+
111190
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
112191
node.args = [
113192
ast.Name(id=function_name, ctx=ast.Load()),
@@ -128,11 +207,14 @@ def find_and_update_line_node(
128207
*call_node.args,
129208
]
130209
node.keywords = call_node.keywords
210+
211+
# Return the signature binding statements along with the test_node
212+
return_statement = [bind_call, apply_defaults, test_node]
131213
break
132214

133215
if call_node is None:
134216
return None
135-
return [test_node]
217+
return return_statement
136218

137219
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
138220
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
@@ -590,6 +672,7 @@ def inject_profiling_into_existing_test(
590672
ast.Import(names=[ast.alias(name="time")]),
591673
ast.Import(names=[ast.alias(name="gc")]),
592674
ast.Import(names=[ast.alias(name="os")]),
675+
ast.Import(names=[ast.alias(name="inspect")]),
593676
]
594677
if mode == TestingMode.BEHAVIOR:
595678
new_imports.extend(

0 commit comments

Comments
 (0)