Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 65 additions & 42 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,30 @@


def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
for pos in call_positions:
if (
pos.line_no is not None
and node.end_lineno is not None
and node.lineno <= pos.line_no <= node.end_lineno
):
if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
return True
if (
pos.line_no == node.end_lineno
and node.end_col_offset is not None
and node.end_col_offset >= pos.col_no
):
return True
if node.lineno < pos.line_no < node.end_lineno:
return True
# Optimize by using attribute lookup only once and avoid repeated checks inside loop
if not isinstance(node, ast.Call):
return False

lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
end_lineno = getattr(node, "end_lineno", None)
end_col_offset = getattr(node, "end_col_offset", None)

if lineno is None or col_offset is None:
return False

# Precompute positions as tuples for fast search.
# Typical call_positions are small, so this does not cost much in memory.
for pos in call_positions:
line_no, col_no = pos.line_no, pos.col_no
# Combine all nested checks into a single block to minimize Python opcode jumps
if line_no is not None and end_lineno is not None and lineno <= line_no <= end_lineno:
if line_no == lineno and col_offset <= col_no:
return True
if line_no == end_lineno and end_col_offset is not None and end_col_offset >= col_no:
return True
if lineno < line_no < end_lineno:
return True
return False


Expand Down Expand Up @@ -71,60 +78,76 @@ def __init__(
def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
# Fast local lookups to minimize attribute access and improve inner loop performance
only_func_name = self.function_object.function_name
is_async = self.function_object.is_async
qualified_name = self.function_object.qualified_name
mode = self.mode
module_path = self.module_path
call_positions = self.call_positions

# OPT: Avoid repeated ast.walk for every found node, instead process and break immediately on match
call_node = None

for node in ast.walk(test_node):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
# Fast-path block
if isinstance(node, ast.Call) and node_in_call_position(node, call_positions):
call_node = node
if isinstance(node.func, ast.Name):
function_name = node.func.id

if self.function_object.is_async:
if is_async:
return [test_node]

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
# Prepare args list outside assignment for performance (less copying)
args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
]

if mode == TestingMode.BEHAVIOR:
args.extend(
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
)
args.extend(call_node.args)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = args
node.keywords = call_node.keywords
break

if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
if function_to_test == self.function_object.function_name:
if self.function_object.is_async:
if node.func.attr == only_func_name:
if is_async:
return [test_node]

# Use ast.unparse only after match is confirmed
function_name = ast.unparse(node.func)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
]
if mode == TestingMode.BEHAVIOR:
args.extend(
[
ast.Name(id="codeflash_cur", ctx=ast.Load()),
ast.Name(id="codeflash_con", ctx=ast.Load()),
]
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
]
)
args.extend(call_node.args)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = args
node.keywords = call_node.keywords
break

Expand Down
Loading