Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 133 additions & 14 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from __future__ import annotations

import ast
import os
import re
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Union

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from codeflash.result.critic import performance_gain
from codeflash.verification.verification_utils import TestConfig

if TYPE_CHECKING:
from codeflash.models.models import InvocationId
from codeflash.verification.verification_utils import TestConfig


def remove_functions_from_generated_tests(
Expand Down Expand Up @@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
return GeneratedTestsList(generated_tests=new_generated_tests)


class CfoVisitor(ast.NodeVisitor):
"""AST visitor that finds all assignments to a variable named 'codeflash_output'.

and reports their location relative to the function they're in.
"""

def __init__(self, source_code: str) -> None:
self.source_lines = source_code.splitlines()
self.results: list[int] = [] # map actual line number to line number in ast

def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool: # type: ignore[type-arg]
"""Check if the assignment target is the variable 'codeflash_output'."""
if isinstance(target, ast.Name):
return target.id == "codeflash_output"
if isinstance(target, (ast.Tuple, ast.List)):
# Handle tuple/list unpacking: a, codeflash_output, b = values
return any(self._is_codeflash_output_target(elt) for elt in target.elts)
if isinstance(target, (ast.Subscript, ast.Attribute)):
Comment on lines +57 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 76% (0.76x) speedup for CfoVisitor._is_codeflash_output_target in codeflash/code_utils/edit_generated_tests.py

⏱️ Runtime : 1.33 milliseconds 757 microseconds (best of 93 runs)

📝 Explanation and details Here is an optimized rewrite of your program with a focus on improving performance in the `_is_codeflash_output_target` method. These changes focus on avoiding unnecessary generator creation, reducing function calls, and using faster data structures and comparison strategies.

Key optimizations:

  • Use type(...) is ... instead of isinstance() for frequently called type checks for speed.
  • Unroll the any(...) generator over target.elts into an explicit for loop to reduce function call and overhead in the hot path.
  • Avoid recursive generator creation where possible; use explicit returns.
  • Avoid isinstance(..., (a, b)) for hot paths (for small type hierarchies and known types); it's slightly faster to compare types directly.

These micro-optimizations especially benefit scenarios with large ASTs or heavy tuple/list unpackings.
The function's logic and return value are identical to your original version.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 104 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from typing import Union

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.edit_generated_tests import CfoVisitor

# unit tests

# Helper function to parse a string into an AST assignment target node
def parse_assignment_target(expr: str):
    """Parses an assignment target (left-hand side) expression and returns the AST node."""
    # We add '= 1' to make it a valid assignment statement
    node = ast.parse(f"{expr} = 1").body[0]
    # In Python AST, targets is a list; we test each target individually
    return node.targets[0]

# 1. Basic Test Cases

def test_simple_name_true():
    # Should return True for variable named 'codeflash_output'
    visitor = CfoVisitor("")
    target = parse_assignment_target("codeflash_output")
    codeflash_output = visitor._is_codeflash_output_target(target) # 731ns -> 722ns (1.25% faster)

def test_simple_name_false():
    # Should return False for variable not named 'codeflash_output'
    visitor = CfoVisitor("")
    target = parse_assignment_target("not_codeflash_output")
    codeflash_output = visitor._is_codeflash_output_target(target) # 651ns -> 591ns (10.2% faster)

def test_tuple_with_codeflash_output():
    # Should return True if 'codeflash_output' is one of the tuple elements
    visitor = CfoVisitor("")
    target = parse_assignment_target("a, codeflash_output, b")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.09μs -> 1.21μs (155% faster)

def test_tuple_without_codeflash_output():
    # Should return False if 'codeflash_output' is not present in tuple
    visitor = CfoVisitor("")
    target = parse_assignment_target("a, b, c")
    codeflash_output = visitor._is_codeflash_output_target(target) # 2.37μs -> 1.23μs (92.7% faster)

def test_list_with_codeflash_output():
    # Should return True if 'codeflash_output' is one of the list elements
    visitor = CfoVisitor("")
    target = parse_assignment_target("[a, codeflash_output, b]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.02μs -> 1.14μs (165% faster)

def test_list_without_codeflash_output():
    # Should return False if 'codeflash_output' is not present in list
    visitor = CfoVisitor("")
    target = parse_assignment_target("[a, b, c]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 2.35μs -> 1.21μs (93.5% faster)

# 2. Edge Test Cases

def test_nested_tuple_with_codeflash_output():
    # Should return True if 'codeflash_output' is nested inside a tuple
    visitor = CfoVisitor("")
    target = parse_assignment_target("a, (b, codeflash_output), c")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.99μs -> 1.75μs (127% faster)

def test_nested_list_with_codeflash_output():
    # Should return True if 'codeflash_output' is nested inside a list
    visitor = CfoVisitor("")
    target = parse_assignment_target("[a, [b, codeflash_output], c]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.86μs -> 1.87μs (106% faster)

def test_deeply_nested_tuple_and_list():
    # Should return True if 'codeflash_output' is deeply nested
    visitor = CfoVisitor("")
    target = parse_assignment_target("a, (b, [c, (d, codeflash_output)])")
    codeflash_output = visitor._is_codeflash_output_target(target) # 5.55μs -> 2.73μs (104% faster)

def test_attribute_assignment():
    # Should return False for attribute assignment (e.g., obj.codeflash_output)
    visitor = CfoVisitor("")
    target = parse_assignment_target("obj.codeflash_output")
    codeflash_output = visitor._is_codeflash_output_target(target) # 922ns -> 661ns (39.5% faster)

def test_subscript_assignment():
    # Should return False for subscript assignment (e.g., codeflash_output[0])
    visitor = CfoVisitor("")
    target = parse_assignment_target("codeflash_output[0]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 1.12μs -> 672ns (67.0% faster)

def test_tuple_with_attribute_and_name():
    # Should return True if at least one element is 'codeflash_output', even if others are attributes
    visitor = CfoVisitor("")
    target = parse_assignment_target("(obj.a, codeflash_output)")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.29μs -> 1.20μs (173% faster)

def test_tuple_with_only_attributes():
    # Should return False if all elements are attributes
    visitor = CfoVisitor("")
    target = parse_assignment_target("(obj.a, obj.codeflash_output)")
    codeflash_output = visitor._is_codeflash_output_target(target) # 2.44μs -> 1.14μs (114% faster)

def test_list_with_subscript_and_name():
    # Should return True if at least one element is 'codeflash_output', even if others are subscripts
    visitor = CfoVisitor("")
    target = parse_assignment_target("[a[0], codeflash_output]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.08μs -> 1.26μs (144% faster)

def test_list_with_only_subscripts():
    # Should return False if all elements are subscripts
    visitor = CfoVisitor("")
    target = parse_assignment_target("[a[0], b[1]]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 2.30μs -> 1.22μs (88.5% faster)

def test_empty_tuple():
    # Should return False for empty tuple
    visitor = CfoVisitor("")
    target = parse_assignment_target("()")
    codeflash_output = visitor._is_codeflash_output_target(target) # 1.36μs -> 731ns (86.3% faster)

def test_empty_list():
    # Should return False for empty list
    visitor = CfoVisitor("")
    target = parse_assignment_target("[]")
    codeflash_output = visitor._is_codeflash_output_target(target) # 1.36μs -> 781ns (74.4% faster)

def test_non_ast_input():
    # Should return False for non-AST input (shouldn't raise)
    visitor = CfoVisitor("")
    codeflash_output = visitor._is_codeflash_output_target("not an ast node") # 901ns -> 722ns (24.8% faster)

def test_tuple_with_non_ast_elements():
    # Should return False if tuple contains non-AST elements
    visitor = CfoVisitor("")
    # Manually create a tuple node with non-AST elements
    fake_tuple = ast.Tuple(elts=["not_ast", "also_not_ast"])
    codeflash_output = visitor._is_codeflash_output_target(fake_tuple) # 2.33μs -> 1.10μs (111% faster)

def test_list_with_non_ast_elements():
    # Should return False if list contains non-AST elements
    visitor = CfoVisitor("")
    fake_list = ast.List(elts=["not_ast", "also_not_ast"])
    codeflash_output = visitor._is_codeflash_output_target(fake_list) # 2.20μs -> 1.10μs (100% faster)

# 3. Large Scale Test Cases

def test_large_tuple_with_codeflash_output_at_start():
    # Should return True if 'codeflash_output' is the first element in a large tuple
    visitor = CfoVisitor("")
    names = ["codeflash_output"] + [f"var{i}" for i in range(999)]
    expr = ", ".join(names)
    target = parse_assignment_target(expr)
    codeflash_output = visitor._is_codeflash_output_target(target) # 3.23μs -> 1.21μs (166% faster)

def test_large_tuple_with_codeflash_output_at_end():
    # Should return True if 'codeflash_output' is the last element in a large tuple
    visitor = CfoVisitor("")
    names = [f"var{i}" for i in range(999)] + ["codeflash_output"]
    expr = ", ".join(names)
    target = parse_assignment_target(expr)
    codeflash_output = visitor._is_codeflash_output_target(target) # 135μs -> 75.4μs (80.2% faster)

def test_large_tuple_without_codeflash_output():
    # Should return False if 'codeflash_output' is not present in a large tuple
    visitor = CfoVisitor("")
    names = [f"var{i}" for i in range(1000)]
    expr = ", ".join(names)
    target = parse_assignment_target(expr)
    codeflash_output = visitor._is_codeflash_output_target(target) # 134μs -> 74.7μs (80.0% faster)

def test_large_list_with_codeflash_output_in_middle():
    # Should return True if 'codeflash_output' is in the middle of a large list
    visitor = CfoVisitor("")
    names = [f"var{i}" for i in range(500)] + ["codeflash_output"] + [f"var{i}" for i in range(500, 999)]
    expr = "[" + ", ".join(names) + "]"
    target = parse_assignment_target(expr)
    codeflash_output = visitor._is_codeflash_output_target(target) # 70.3μs -> 38.6μs (82.2% faster)




from __future__ import annotations

import ast
from typing import Union

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.edit_generated_tests import CfoVisitor

# unit tests

@pytest.fixture
def visitor():
    # Provide a dummy source code string; not used in _is_codeflash_output_target
    return CfoVisitor("dummy\nsource\ncode")

# --------------------
# 1. Basic Test Cases
# --------------------

def test_simple_name_match(visitor):
    # Should return True for variable named 'codeflash_output'
    node = ast.Name(id="codeflash_output")
    codeflash_output = visitor._is_codeflash_output_target(node) # 781ns -> 721ns (8.32% faster)

def test_simple_name_no_match(visitor):
    # Should return False for other variable names
    node = ast.Name(id="not_codeflash_output")
    codeflash_output = visitor._is_codeflash_output_target(node) # 651ns -> 631ns (3.17% faster)

def test_tuple_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is in tuple assignment
    node = ast.Tuple(elts=[
        ast.Name(id="a"),
        ast.Name(id="codeflash_output"),
        ast.Name(id="b"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 3.00μs -> 1.17μs (156% faster)

def test_tuple_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not in tuple
    node = ast.Tuple(elts=[
        ast.Name(id="a"),
        ast.Name(id="b"),
        ast.Name(id="c"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.19μs -> 1.23μs (78.2% faster)

def test_list_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is in list assignment
    node = ast.List(elts=[
        ast.Name(id="x"),
        ast.Name(id="codeflash_output"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.81μs -> 1.15μs (143% faster)

def test_list_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not in list
    node = ast.List(elts=[
        ast.Name(id="x"),
        ast.Name(id="y"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.05μs -> 1.04μs (97.3% faster)

def test_subscript_target(visitor):
    # Should return False for subscript assignment (e.g., codeflash_output[0] = ...)
    node = ast.Subscript(
        value=ast.Name(id="codeflash_output"),
        slice=ast.Index(value=ast.Constant(value=0)),
        ctx=ast.Store(),
    )
    codeflash_output = visitor._is_codeflash_output_target(node) # 931ns -> 661ns (40.8% faster)

def test_attribute_target(visitor):
    # Should return False for attribute assignment (e.g., codeflash_output.attr = ...)
    node = ast.Attribute(
        value=ast.Name(id="codeflash_output"),
        attr="attr",
        ctx=ast.Store(),
    )
    codeflash_output = visitor._is_codeflash_output_target(node) # 852ns -> 701ns (21.5% faster)

# --------------------
# 2. Edge Test Cases
# --------------------

def test_nested_tuple_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is nested inside tuple
    node = ast.Tuple(elts=[
        ast.Name(id="a"),
        ast.Tuple(elts=[
            ast.Name(id="b"),
            ast.Name(id="codeflash_output"),
        ]),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 3.87μs -> 1.75μs (121% faster)

def test_nested_list_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is nested inside list
    node = ast.List(elts=[
        ast.List(elts=[
            ast.Name(id="codeflash_output"),
        ]),
        ast.Name(id="c"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 3.24μs -> 1.59μs (103% faster)

def test_deeply_nested_structure_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not present anywhere
    node = ast.Tuple(elts=[
        ast.List(elts=[
            ast.Tuple(elts=[
                ast.Name(id="x")
            ])
        ])
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 3.00μs -> 1.91μs (56.5% faster)

def test_tuple_with_subscript_and_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is present, even if other elements are subscripts
    node = ast.Tuple(elts=[
        ast.Subscript(
            value=ast.Name(id="foo"),
            slice=ast.Index(value=ast.Constant(value=0)),
            ctx=ast.Store(),
        ),
        ast.Name(id="codeflash_output"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.88μs -> 1.13μs (155% faster)

def test_tuple_with_only_subscripts(visitor):
    # Should return False if tuple contains only subscript elements
    node = ast.Tuple(elts=[
        ast.Subscript(
            value=ast.Name(id="foo"),
            slice=ast.Index(value=ast.Constant(value=0)),
            ctx=ast.Store(),
        ),
        ast.Subscript(
            value=ast.Name(id="bar"),
            slice=ast.Index(value=ast.Constant(value=1)),
            ctx=ast.Store(),
        ),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.25μs -> 1.12μs (101% faster)

def test_tuple_with_attribute_and_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is present, even if other elements are attributes
    node = ast.Tuple(elts=[
        ast.Attribute(
            value=ast.Name(id="foo"),
            attr="bar",
            ctx=ast.Store(),
        ),
        ast.Name(id="codeflash_output"),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.90μs -> 1.12μs (158% faster)

def test_tuple_with_only_attributes(visitor):
    # Should return False if tuple contains only attribute elements
    node = ast.Tuple(elts=[
        ast.Attribute(
            value=ast.Name(id="foo"),
            attr="bar",
            ctx=ast.Store(),
        ),
        ast.Attribute(
            value=ast.Name(id="baz"),
            attr="qux",
            ctx=ast.Store(),
        ),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.31μs -> 1.16μs (99.2% faster)

def test_empty_tuple(visitor):
    # Should return False for empty tuple
    node = ast.Tuple(elts=[])
    codeflash_output = visitor._is_codeflash_output_target(node) # 1.37μs -> 722ns (90.0% faster)

def test_empty_list(visitor):
    # Should return False for empty list
    node = ast.List(elts=[])
    codeflash_output = visitor._is_codeflash_output_target(node) # 1.35μs -> 801ns (68.8% faster)

def test_invalid_type(visitor):
    # Should return False for unsupported node types (e.g., ast.Constant)
    node = ast.Constant(value=42)
    codeflash_output = visitor._is_codeflash_output_target(node) # 911ns -> 691ns (31.8% faster)

def test_mixed_types_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is present among mixed node types
    node = ast.Tuple(elts=[
        ast.Constant(value=1),
        ast.Name(id="codeflash_output"),
        ast.Attribute(
            value=ast.Name(id="foo"),
            attr="bar",
            ctx=ast.Store(),
        ),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 3.04μs -> 1.16μs (161% faster)

def test_mixed_types_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not present among mixed node types
    node = ast.Tuple(elts=[
        ast.Constant(value=1),
        ast.Attribute(
            value=ast.Name(id="foo"),
            attr="bar",
            ctx=ast.Store(),
        ),
    ])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.30μs -> 1.11μs (107% faster)

# --------------------
# 3. Large Scale Test Cases
# --------------------

def test_large_tuple_with_codeflash_output_at_start(visitor):
    # Should return True if 'codeflash_output' is at the start of a large tuple
    node = ast.Tuple(elts=[ast.Name(id="codeflash_output")] + [ast.Name(id=f"x{i}") for i in range(999)])
    codeflash_output = visitor._is_codeflash_output_target(node) # 2.83μs -> 1.14μs (148% faster)

def test_large_tuple_with_codeflash_output_at_end(visitor):
    # Should return True if 'codeflash_output' is at the end of a large tuple
    node = ast.Tuple(elts=[ast.Name(id=f"x{i}") for i in range(999)] + [ast.Name(id="codeflash_output")])
    codeflash_output = visitor._is_codeflash_output_target(node) # 146μs -> 85.4μs (72.1% faster)

def test_large_tuple_with_codeflash_output_in_middle(visitor):
    # Should return True if 'codeflash_output' is in the middle of a large tuple
    elts = [ast.Name(id=f"x{i}") for i in range(500)] + [ast.Name(id="codeflash_output")] + [ast.Name(id=f"x{i}") for i in range(500, 999)]
    node = ast.Tuple(elts=elts)
    codeflash_output = visitor._is_codeflash_output_target(node) # 76.6μs -> 43.9μs (74.7% faster)

def test_large_tuple_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not present in a large tuple
    node = ast.Tuple(elts=[ast.Name(id=f"x{i}") for i in range(1000)])
    codeflash_output = visitor._is_codeflash_output_target(node) # 147μs -> 86.9μs (69.8% faster)

def test_large_nested_tuple_with_codeflash_output(visitor):
    # Should return True if 'codeflash_output' is deeply nested in a large structure
    inner = ast.Tuple(elts=[ast.Name(id=f"x{i}") for i in range(999)] + [ast.Name(id="codeflash_output")])
    outer = ast.Tuple(elts=[ast.Name(id="a"), inner, ast.Name(id="b")])
    codeflash_output = visitor._is_codeflash_output_target(outer) # 147μs -> 87.7μs (68.5% faster)

def test_large_nested_tuple_without_codeflash_output(visitor):
    # Should return False if 'codeflash_output' is not present in any nested structure
    inner = ast.Tuple(elts=[ast.Name(id=f"x{i}") for i in range(1000)])
    outer = ast.Tuple(elts=[ast.Name(id="a"), inner, ast.Name(id="b")])
    codeflash_output = visitor._is_codeflash_output_target(outer) # 148μs -> 88.2μs (67.9% faster)

def test_large_list_with_codeflash_output(visitor):
    # Should return True for large list containing 'codeflash_output'
    node = ast.List(elts=[ast.Name(id=f"x{i}") for i in range(500)] + [ast.Name(id="codeflash_output")] + [ast.Name(id=f"x{i}") for i in range(501, 1000)])
    codeflash_output = visitor._is_codeflash_output_target(node) # 75.8μs -> 42.8μs (77.1% faster)

def test_large_list_without_codeflash_output(visitor):
    # Should return False for large list without 'codeflash_output'
    node = ast.List(elts=[ast.Name(id=f"x{i}") for i in range(1000)])
    codeflash_output = visitor._is_codeflash_output_target(node) # 149μs -> 86.5μs (73.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr358-2025-06-21T00.08.50

Click to see suggested changes
Suggested change
def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool: # type: ignore[type-arg]
"""Check if the assignment target is the variable 'codeflash_output'."""
if isinstance(target, ast.Name):
return target.id == "codeflash_output"
if isinstance(target, (ast.Tuple, ast.List)):
# Handle tuple/list unpacking: a, codeflash_output, b = values
return any(self._is_codeflash_output_target(elt) for elt in target.elts)
if isinstance(target, (ast.Subscript, ast.Attribute)):
def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool:
"""Check if the assignment target is the variable 'codeflash_output'."""
t_type = type(target)
if t_type is ast.Name:
return target.id == "codeflash_output"
if t_type is ast.Tuple or t_type is ast.List:
# Handle tuple/list unpacking: a, codeflash_output, b = values
# Flatten elts (avoid generator, avoid recursion if not necessary):
elts = target.elts
for elt in elts:
if type(elt) is ast.Name:
if elt.id == "codeflash_output":
return True
elif type(elt) is ast.Tuple or type(elt) is ast.List:
# Recursively check nested tuples/lists, but avoid generator
if self._is_codeflash_output_target(elt):
return True
return False
if t_type is ast.Subscript or t_type is ast.Attribute:

# Not a simple variable assignment
return False
return False

def _record_assignment(self, node: ast.AST) -> None:
"""Record an assignment to codeflash_output."""
relative_line = node.lineno - 1 # type: ignore[attr-defined]
self.results.append(relative_line)

def visit_Assign(self, node: ast.Assign) -> None:
"""Visit assignment statements: codeflash_output = value."""
for target in node.targets:
if self._is_codeflash_output_target(target):
self._record_assignment(node)
break
self.generic_visit(node)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Visit annotated assignments: codeflash_output: int = value."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
"""Visit augmented assignments: codeflash_output += value."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
"""Visit walrus operator: (codeflash_output := value)."""
if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
self._record_assignment(node)
self.generic_visit(node)

def visit_For(self, node: ast.For) -> None:
"""Visit for loops: for codeflash_output in iterable."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_comprehension(self, node: ast.comprehension) -> None:
"""Visit comprehensions: [x for codeflash_output in iterable]."""
if self._is_codeflash_output_target(node.target):
# Comprehensions don't have line numbers, so we skip recording
pass
self.generic_visit(node)

def visit_With(self, node: ast.With) -> None:
"""Visit with statements: with expr as codeflash_output."""
for item in node.items:
if item.optional_vars and self._is_codeflash_output_target(item.optional_vars):
self._record_assignment(node)
break
self.generic_visit(node)

def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
"""Visit except handlers: except Exception as codeflash_output."""
if node.name == "codeflash_output":
self._record_assignment(node)
self.generic_visit(node)


def find_codeflash_output_assignments(source_code: str) -> list[int]:
tree = ast.parse(source_code)
visitor = CfoVisitor(source_code)
visitor.visit(tree)
return visitor.results


def add_runtime_comments_to_generated_tests(
test_cfg: TestConfig,
generated_tests: GeneratedTestsList,
Expand All @@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(

# TODO: reduce for loops to one
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
super().__init__()
self.test = test
self.context_stack: list[str] = []
self.tests_root = tests_root
self.rel_tests_root = rel_tests_root
self.module = module
self.cfo_locs: list[int] = []
self.cfo_idx_loc_to_look_at: int = -1

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
Expand All @@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
# convert function body to ast normalized string and find occurrences of codeflash_output
body_code = dedent(self.module.code_for_node(node.body))
normalized_body_code = ast.unparse(ast.parse(body_code))
self.cfo_locs = sorted(
find_codeflash_output_assignments(normalized_body_code)
) # sorted in order we will encounter them
self.cfo_idx_loc_to_look_at = -1
self.context_stack.append(node.name.value)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
Expand All @@ -91,10 +198,12 @@ def leave_SimpleStatementLine(

if codeflash_assignment_found:
# Find matching test cases by looking for this test function name in the test results
self.cfo_idx_loc_to_look_at += 1
matching_original_times = []
matching_optimized_times = []
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
for invocation_id, runtimes in original_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
Expand All @@ -105,13 +214,19 @@ def leave_SimpleStatementLine(
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if qualified_name == ".".join(self.context_stack) and rel_path in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]:
if (
qualified_name == ".".join(self.context_stack)
and rel_path
in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_original_times.extend(runtimes)

for invocation_id, runtimes in optimized_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
Expand All @@ -122,10 +237,15 @@ def leave_SimpleStatementLine(
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if qualified_name == ".".join(self.context_stack) and rel_path in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]:
if (
qualified_name == ".".join(self.context_stack)
and rel_path
in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
Expand Down Expand Up @@ -161,9 +281,8 @@ def leave_SimpleStatementLine(
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)

# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
modified_tree = tree.visit(transformer)

# Convert back to source code
Expand Down
Loading
Loading