Skip to content

Commit 330bf91

Browse files
fix code replacement tests
1 parent 99cd9dc commit 330bf91

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

codeflash/models/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,11 @@ def markdown(self) -> str:
169169
]
170170
)
171171

172+
def path_to_code_string(self) -> dict[str, str]:
173+
return {code_string.file_path: code_string.code for code_string in self.code_strings}
174+
172175
@staticmethod
173-
def from_str_with_markers(code_with_markers: str) -> list[CodeString]:
176+
def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown:
174177
pattern = rf"{SPLITTER_MARKER}([^\n]+)\n"
175178
matches = list(re.finditer(pattern, code_with_markers))
176179

@@ -181,7 +184,7 @@ def from_str_with_markers(code_with_markers: str) -> list[CodeString]:
181184
file_path = match.group(1).strip()
182185
code = code_with_markers[start:end].lstrip("\n")
183186
results.append(CodeString(file_path=file_path, code=code))
184-
return results
187+
return CodeStringsMarkdown(code_strings=results)
185188

186189

187190
class CodeOptimizationContext(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,18 +621,22 @@ def replace_function_and_helpers_with_optimized_code(
621621
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
622622
self.function_to_optimize.qualified_name
623623
)
624-
code_strings = CodeStringsMarkdown.from_str_with_markers(optimized_code)
625-
optimized_code_dict = {code_string.file_path: code_string.code for code_string in code_strings}
626-
logger.debug(f"Optimized code: {optimized_code_dict}")
624+
file_to_code_context = CodeStringsMarkdown.from_str_with_markers(optimized_code).path_to_code_string()
625+
logger.debug(f"Optimized code: {file_to_code_context}")
627626
for helper_function in code_context.helper_functions:
628627
if helper_function.jedi_definition.type != "class":
629628
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
630629
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
631630
relative_module_path = module_abspath.relative_to(self.project_root)
632631
logger.debug(f"applying optimized code to: {relative_module_path}")
632+
633+
optimized_code = file_to_code_context.get(relative_module_path)
634+
if not optimized_code:
635+
msg = f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'"
636+
raise ValueError(msg)
633637
did_update |= replace_function_definitions_in_module(
634638
function_names=list(qualified_names),
635-
optimized_code=optimized_code_dict.get(relative_module_path),
639+
optimized_code=optimized_code,
636640
module_abspath=module_abspath,
637641
preexisting_objects=code_context.preexisting_objects,
638642
project_root_path=self.project_root,

tests/test_code_replacement.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
replace_functions_in_file,
1414
)
1515
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
16-
from codeflash.models.models import CodeOptimizationContext, FunctionParent
16+
from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter
1717
from codeflash.optimization.function_optimizer import FunctionOptimizer
1818
from codeflash.verification.verification_utils import TestConfig
1919

@@ -41,11 +41,14 @@ class Args:
4141

4242

4343
def test_code_replacement_global_statements():
44-
optimized_code = """import numpy as np
44+
project_root = Path(__file__).parent.parent.resolve()
45+
code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve()
46+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))}
47+
import numpy as np
48+
4549
inconsequential_var = '123'
4650
def sorter(arr):
4751
return arr.sort()"""
48-
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve()
4952
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
5053
encoding="utf-8"
5154
)
@@ -1666,6 +1669,9 @@ def new_function2(value):
16661669

16671670

16681671
def test_global_reassignment() -> None:
1672+
root_dir = Path(__file__).parent.parent.resolve()
1673+
code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve()
1674+
16691675
original_code = """a=1
16701676
print("Hello world")
16711677
def some_fn():
@@ -1678,7 +1684,9 @@ def __call__(self, value):
16781684
def new_function2(value):
16791685
return cst.ensure_type(value, str)
16801686
"""
1681-
optimized_code = """import numpy as np
1687+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
1688+
import numpy as np
1689+
16821690
def some_fn():
16831691
a=np.zeros(10)
16841692
print("did something")
@@ -1713,7 +1721,6 @@ def __call__(self, value):
17131721
return "I am still old"
17141722
def new_function2(value):
17151723
return cst.ensure_type(value, str)"""
1716-
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
17171724
code_path.write_text(original_code, encoding="utf-8")
17181725
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
17191726
project_root_path = (Path(__file__).parent / "..").resolve()
@@ -1753,7 +1760,8 @@ def new_function2(value):
17531760
return cst.ensure_type(value, str)
17541761
a=1
17551762
"""
1756-
optimized_code = """a=2
1763+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
1764+
a=2
17571765
import numpy as np
17581766
def some_fn():
17591767
a=np.zeros(10)
@@ -1829,7 +1837,8 @@ def __call__(self, value):
18291837
def new_function2(value):
18301838
return cst.ensure_type(value, str)
18311839
"""
1832-
optimized_code = """import numpy as np
1840+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
1841+
import numpy as np
18331842
a=2
18341843
def some_fn():
18351844
a=np.zeros(10)
@@ -1906,7 +1915,8 @@ def __call__(self, value):
19061915
def new_function2(value):
19071916
return cst.ensure_type(value, str)
19081917
"""
1909-
optimized_code = """a=2
1918+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
1919+
a=2
19101920
import numpy as np
19111921
def some_fn():
19121922
a=np.zeros(10)
@@ -1982,7 +1992,8 @@ def __call__(self, value):
19821992
def new_function2(value):
19831993
return cst.ensure_type(value, str)
19841994
"""
1985-
optimized_code = """import numpy as np
1995+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
1996+
import numpy as np
19861997
a=2
19871998
def some_fn():
19881999
a=np.zeros(10)
@@ -2062,7 +2073,8 @@ def __call__(self, value):
20622073
def new_function2(value):
20632074
return cst.ensure_type(value, str)
20642075
"""
2065-
optimized_code = """import numpy as np
2076+
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
2077+
import numpy as np
20662078
if 1<2:
20672079
a=2
20682080
else:

0 commit comments

Comments
 (0)