Skip to content

Commit 44d9229

Browse files
committed
assignments with if/else blocks are not modified
1 parent a8cf3ee commit 44d9229

File tree

2 files changed

+114
-2
lines changed

2 files changed

+114
-2
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self):
2929
self.assignment_order: List[str] = []
3030
# Track scope depth to identify global assignments
3131
self.scope_depth = 0
32+
self.if_else_depth = 0
3233

3334
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
3435
self.scope_depth += 1
@@ -44,9 +45,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
4445
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
4546
self.scope_depth -= 1
4647

48+
def visit_If(self, node: cst.If) -> Optional[bool]:
49+
self.if_else_depth += 1
50+
return True
51+
52+
def leave_If(self, original_node: cst.If) -> None:
53+
self.if_else_depth -= 1
54+
55+
def visit_Else(self, node: cst.Else) -> Optional[bool]:
56+
# Else blocks are already counted as part of the if statement
57+
return True
58+
4759
def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
4860
# Only process global assignments (not inside functions, classes, etc.)
49-
if self.scope_depth == 0: # We're at module level
61+
if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level
5062
for target in node.targets:
5163
if isinstance(target.target, cst.Name):
5264
name = target.target.value
@@ -65,6 +77,7 @@ def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order:
6577
self.new_assignment_order = new_assignment_order
6678
self.processed_assignments: Set[str] = set()
6779
self.scope_depth = 0
80+
self.if_else_depth = 0
6881

6982
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
7083
self.scope_depth += 1
@@ -80,8 +93,19 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
8093
self.scope_depth -= 1
8194
return updated_node
8295

96+
def visit_If(self, node: cst.If) -> None:
97+
self.if_else_depth += 1
98+
99+
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
100+
self.if_else_depth -= 1
101+
return updated_node
102+
103+
def visit_Else(self, node: cst.Else) -> None:
104+
# Else blocks are already counted as part of the if statement
105+
pass
106+
83107
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
84-
if self.scope_depth > 0:
108+
if self.scope_depth > 0 or self.if_else_depth > 0:
85109
return updated_node
86110

87111
# Check if this is a global assignment we need to replace

tests/test_code_replacement.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,6 +2017,94 @@ def __call__(self, value):
20172017
return "I am still old"
20182018
def new_function2(value):
20192019
return cst.ensure_type(value, str)
2020+
"""
2021+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
2022+
code_path.write_text(original_code, encoding="utf-8")
2023+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
2024+
project_root_path = (Path(__file__).parent / "..").resolve()
2025+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
2026+
test_config = TestConfig(
2027+
tests_root=tests_root,
2028+
tests_project_rootdir=project_root_path,
2029+
project_root_path=project_root_path,
2030+
test_framework="pytest",
2031+
pytest_cmd="pytest",
2032+
)
2033+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
2034+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
2035+
original_helper_code: dict[Path, str] = {}
2036+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
2037+
for helper_function_path in helper_function_paths:
2038+
with helper_function_path.open(encoding="utf8") as f:
2039+
helper_code = f.read()
2040+
original_helper_code[helper_function_path] = helper_code
2041+
func_optimizer.args = Args()
2042+
func_optimizer.replace_function_and_helpers_with_optimized_code(
2043+
code_context=code_context, optimized_code=optimized_code
2044+
)
2045+
new_code = code_path.read_text(encoding="utf-8")
2046+
code_path.unlink(missing_ok=True)
2047+
assert new_code.rstrip() == expected_code.rstrip()
2048+
2049+
original_code = """if 2<3:
2050+
a=4
2051+
else:
2052+
a=5
2053+
print("Hello world")
2054+
def some_fn():
2055+
print("did noting")
2056+
class NewClass:
2057+
def __init__(self, name):
2058+
self.name = name
2059+
def __call__(self, value):
2060+
return "I am still old"
2061+
def new_function2(value):
2062+
return cst.ensure_type(value, str)
2063+
"""
2064+
optimized_code = """import numpy as np
2065+
if 1<2:
2066+
a=2
2067+
else:
2068+
a=3
2069+
a = 6
2070+
def some_fn():
2071+
a=np.zeros(10)
2072+
print("did something")
2073+
class NewClass:
2074+
def __init__(self, name):
2075+
self.name = name
2076+
def __call__(self, value):
2077+
return "I am still old"
2078+
def new_function2(value):
2079+
return cst.ensure_type(value, str)
2080+
print("Hello world")
2081+
"""
2082+
expected_code = """import numpy as np
2083+
print("Hello world")
2084+
2085+
if 2<3:
2086+
a=4
2087+
else:
2088+
a=5
2089+
print("Hello world")
2090+
def some_fn():
2091+
a=np.zeros(10)
2092+
print("did something")
2093+
class NewClass:
2094+
def __init__(self, name):
2095+
self.name = name
2096+
def __call__(self, value):
2097+
return "I am still old"
2098+
def new_function2(value):
2099+
return cst.ensure_type(value, str)
2100+
def __init__(self, name):
2101+
self.name = name
2102+
def __call__(self, value):
2103+
return "I am still old"
2104+
def new_function2(value):
2105+
return cst.ensure_type(value, str)
2106+
2107+
a = 6
20202108
"""
20212109
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
20222110
code_path.write_text(original_code, encoding="utf-8")

0 commit comments

Comments
 (0)