1313 replace_functions_in_file ,
1414)
1515from codeflash .discovery .functions_to_optimize import FunctionToOptimize
16- from codeflash .models .models import CodeOptimizationContext , FunctionParent
16+ from codeflash .models .models import CodeOptimizationContext , FunctionParent , get_code_block_splitter
1717from codeflash .optimization .function_optimizer import FunctionOptimizer
1818from codeflash .verification .verification_utils import TestConfig
1919
@@ -41,11 +41,14 @@ class Args:
4141
4242
4343def test_code_replacement_global_statements ():
44- optimized_code = """import numpy as np
44+ project_root = Path (__file__ ).parent .parent .resolve ()
45+ code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py" ).resolve ()
46+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (project_root ))}
47+ import numpy as np
48+
4549inconsequential_var = '123'
4650def sorter(arr):
4751 return arr.sort()"""
48- code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort_optimized.py" ).resolve ()
4952 original_code_str = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort.py" ).read_text (
5053 encoding = "utf-8"
5154 )
@@ -1666,6 +1669,9 @@ def new_function2(value):
16661669
16671670
16681671def test_global_reassignment () -> None :
1672+ root_dir = Path (__file__ ).parent .parent .resolve ()
1673+ code_path = (root_dir / "code_to_optimize/global_var_original.py" ).resolve ()
1674+
16691675 original_code = """a=1
16701676print("Hello world")
16711677def some_fn():
@@ -1678,7 +1684,9 @@ def __call__(self, value):
16781684 def new_function2(value):
16791685 return cst.ensure_type(value, str)
16801686 """
1681- optimized_code = """import numpy as np
1687+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
1688+ import numpy as np
1689+
16821690def some_fn():
16831691 a=np.zeros(10)
16841692 print("did something")
@@ -1713,7 +1721,6 @@ def __call__(self, value):
17131721 return "I am still old"
17141722 def new_function2(value):
17151723 return cst.ensure_type(value, str)"""
1716- code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
17171724 code_path .write_text (original_code , encoding = "utf-8" )
17181725 tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
17191726 project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
@@ -1753,7 +1760,8 @@ def new_function2(value):
17531760 return cst.ensure_type(value, str)
17541761a=1
17551762"""
1756- optimized_code = """a=2
1763+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
1764+ a=2
17571765import numpy as np
17581766def some_fn():
17591767 a=np.zeros(10)
@@ -1829,7 +1837,8 @@ def __call__(self, value):
18291837 def new_function2(value):
18301838 return cst.ensure_type(value, str)
18311839"""
1832- optimized_code = """import numpy as np
1840+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
1841+ import numpy as np
18331842a=2
18341843def some_fn():
18351844 a=np.zeros(10)
@@ -1906,7 +1915,8 @@ def __call__(self, value):
19061915 def new_function2(value):
19071916 return cst.ensure_type(value, str)
19081917"""
1909- optimized_code = """a=2
1918+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
1919+ a=2
19101920import numpy as np
19111921def some_fn():
19121922 a=np.zeros(10)
@@ -1982,7 +1992,8 @@ def __call__(self, value):
19821992 def new_function2(value):
19831993 return cst.ensure_type(value, str)
19841994"""
1985- optimized_code = """import numpy as np
1995+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
1996+ import numpy as np
19861997a=2
19871998def some_fn():
19881999 a=np.zeros(10)
@@ -2062,7 +2073,8 @@ def __call__(self, value):
20622073 def new_function2(value):
20632074 return cst.ensure_type(value, str)
20642075"""
2065- optimized_code = """import numpy as np
2076+ optimized_code = f"""{ get_code_block_splitter (code_path .relative_to (root_dir ))}
2077+ import numpy as np
20662078if 1<2:
20672079 a=2
20682080else:
0 commit comments