diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index db08f8afc..948792b7a 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -20,22 +20,25 @@ def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool: if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"): + node_lineno = node.lineno + node_col_offset = node.col_offset + node_end_lineno = getattr(node, "end_lineno", None) + node_end_col_offset = getattr(node, "end_col_offset", None) for pos in call_positions: - if ( - pos.line_no is not None - and node.end_lineno is not None - and node.lineno <= pos.line_no <= node.end_lineno - ): - if pos.line_no == node.lineno and node.col_offset <= pos.col_no: - return True - if ( - pos.line_no == node.end_lineno - and node.end_col_offset is not None - and node.end_col_offset >= pos.col_no - ): - return True - if node.lineno < pos.line_no < node.end_lineno: - return True + pos_line_no = pos.line_no + pos_col_no = pos.col_no + if pos_line_no is not None and node_end_lineno is not None: + if node_lineno <= pos_line_no <= node_end_lineno: + if pos_line_no == node_lineno and node_col_offset <= pos_col_no: + return True + if ( + pos_line_no == node_end_lineno + and node_end_col_offset is not None + and node_end_col_offset >= pos_col_no + ): + return True + if node_lineno < pos_line_no < node_end_lineno: + return True return False