@@ -804,7 +804,8 @@ def __init__(self, name):
804804 self.name = name
805805
806806 def main_method(self):
807- return HelperClass(self.name).helper_method()"""
807+ return HelperClass(self.name).helper_method()
808+ """
808809 file_path = Path (__file__ ).resolve ()
809810 func_top_optimize = FunctionToOptimize (
810811 function_name = "main_method" , file_path = file_path , parents = [FunctionParent ("MainClass" , "ClassDef" )]
@@ -1665,51 +1666,76 @@ def new_function2(value):
16651666def test_global_reassignment () -> None :
16661667 original_code = """a=1
16671668print("Hello world")
1669+ def some_fn():
1670+ print("did noting")
16681671class NewClass:
16691672 def __init__(self, name):
16701673 self.name = name
16711674 def __call__(self, value):
16721675 return "I am still old"
16731676 def new_function2(value):
16741677 return cst.ensure_type(value, str)
1675- """
1676-
1677- optim_code = """import numpy as np
1678+ """
1679+ optimized_code = """import numpy as np
1680+ def some_fn():
1681+ a=np.zeros(10)
1682+ print("did something")
16781683class NewClass:
16791684 def __init__(self, name):
16801685 self.name = name
16811686 def __call__(self, value):
1682- w = np.array([1,2,3])
1683- return "I am new"
1687+ return "I am still old"
16841688 def new_function2(value):
16851689 return cst.ensure_type(value, str)
16861690a=2
16871691print("Hello world")
1688- """
1689-
1690- modified_code = """import numpy as np
1691-
1692+ """
1693+ expected_code = """import numpy as np
16921694print("Hello world")
1695+
16931696a=2
16941697print("Hello world")
1698+ def some_fn():
1699+ a=np.zeros(10)
1700+ print("did something")
16951701class NewClass:
16961702 def __init__(self, name):
16971703 self.name = name
16981704 def __call__(self, value):
1699- w = np.array([1,2,3])
1700- return "I am new"
1705+ return "I am still old"
1706+ def new_function2(value):
1707+ return cst.ensure_type(value, str)
1708+ def __init__(self, name):
1709+ self.name = name
1710+ def __call__(self, value):
1711+ return "I am still old"
17011712 def new_function2(value):
17021713 return cst.ensure_type(value, str)
17031714"""
1704-
1705- function_names : list [str ] = ["NewClass.__init__" , "NewClass.__call__" , "NewClass.new_function2" ]
1706- preexisting_objects : set [tuple [str , tuple [FunctionParent , ...]]] = find_preexisting_objects (original_code )
1707- new_code : str = replace_functions_and_add_imports (
1708- source_code = original_code ,
1709- function_names = function_names ,
1710- optimized_code = optim_code ,
1711- module_abspath = Path (__file__ ).resolve (),
1712- preexisting_objects = preexisting_objects ,
1713- project_root_path = Path (__file__ ).resolve ().parent .resolve (),
1715+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
1716+ code_path .write_text (original_code , encoding = "utf-8" )
1717+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
1718+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
1719+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
1720+ test_config = TestConfig (
1721+ tests_root = tests_root ,
1722+ tests_project_rootdir = project_root_path ,
1723+ project_root_path = project_root_path ,
1724+ test_framework = "pytest" ,
1725+ pytest_cmd = "pytest" ,
17141726 )
1715- assert new_code == modified_code
1727+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
1728+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
1729+ original_helper_code : dict [Path , str ] = {}
1730+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
1731+ for helper_function_path in helper_function_paths :
1732+ with helper_function_path .open (encoding = "utf8" ) as f :
1733+ helper_code = f .read ()
1734+ original_helper_code [helper_function_path ] = helper_code
1735+ func_optimizer .args = Args ()
1736+ func_optimizer .replace_function_and_helpers_with_optimized_code (
1737+ code_context = code_context , optimized_code = optimized_code
1738+ )
1739+ new_code = code_path .read_text (encoding = "utf-8" )
1740+ code_path .unlink (missing_ok = True )
1741+ assert new_code .rstrip () == expected_code .rstrip ()
0 commit comments