Skip to content

Commit ed9e9b3

Browse files
committed
Update instrument_existing_tests.py
1 parent e8699ec commit ed9e9b3

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def find_and_update_line_node(
7474
) -> Iterable[ast.stmt] | None:
7575
call_node = None
7676
await_node = None
77-
77+
7878
for node in ast.walk(test_node):
7979
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
8080
call_node = node
@@ -123,9 +123,13 @@ def find_and_update_line_node(
123123
]
124124
node.keywords = call_node.keywords
125125
break
126-
126+
127127
# Check for awaited function calls
128-
elif isinstance(node, ast.Await) and isinstance(node.value, ast.Call) and node_in_call_position(node.value, self.call_positions):
128+
elif (
129+
isinstance(node, ast.Await)
130+
and isinstance(node.value, ast.Call)
131+
and node_in_call_position(node.value, self.call_positions)
132+
):
129133
call_node = node.value
130134
await_node = node
131135
if isinstance(call_node.func, ast.Name):
@@ -192,7 +196,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
192196

193197
return node
194198

195-
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: str | None = None) -> ast.AsyncFunctionDef:
199+
def visit_AsyncFunctionDef(
200+
self, node: ast.AsyncFunctionDef, test_class_name: str | None = None
201+
) -> ast.AsyncFunctionDef:
196202
"""Handle async function definitions by converting to sync and back."""
197203
# Convert to sync FunctionDef, process it, then convert back
198204
sync_node = ast.FunctionDef(
@@ -202,7 +208,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: st
202208
decorator_list=node.decorator_list,
203209
returns=node.returns,
204210
lineno=node.lineno,
205-
col_offset=node.col_offset if hasattr(node, 'col_offset') else 0
211+
col_offset=node.col_offset if hasattr(node, "col_offset") else 0,
206212
)
207213
processed_sync = self.visit_FunctionDef(sync_node, test_class_name)
208214
# Convert back to AsyncFunctionDef
@@ -213,9 +219,9 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: st
213219
decorator_list=processed_sync.decorator_list,
214220
returns=processed_sync.returns,
215221
lineno=processed_sync.lineno,
216-
col_offset=processed_sync.col_offset if hasattr(processed_sync, 'col_offset') else 0
222+
col_offset=processed_sync.col_offset if hasattr(processed_sync, "col_offset") else 0,
217223
)
218-
224+
219225
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
220226
if node.name.startswith("test_"):
221227
did_update = False

0 commit comments

Comments
 (0)