@@ -2017,6 +2017,94 @@ def __call__(self, value):
20172017 return "I am still old"
20182018 def new_function2(value):
20192019 return cst.ensure_type(value, str)
2020+ """
2021+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
2022+ code_path .write_text (original_code , encoding = "utf-8" )
2023+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
2024+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
2025+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
2026+ test_config = TestConfig (
2027+ tests_root = tests_root ,
2028+ tests_project_rootdir = project_root_path ,
2029+ project_root_path = project_root_path ,
2030+ test_framework = "pytest" ,
2031+ pytest_cmd = "pytest" ,
2032+ )
2033+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
2034+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
2035+ original_helper_code : dict [Path , str ] = {}
2036+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
2037+ for helper_function_path in helper_function_paths :
2038+ with helper_function_path .open (encoding = "utf8" ) as f :
2039+ helper_code = f .read ()
2040+ original_helper_code [helper_function_path ] = helper_code
2041+ func_optimizer .args = Args ()
2042+ func_optimizer .replace_function_and_helpers_with_optimized_code (
2043+ code_context = code_context , optimized_code = optimized_code
2044+ )
2045+ new_code = code_path .read_text (encoding = "utf-8" )
2046+ code_path .unlink (missing_ok = True )
2047+ assert new_code .rstrip () == expected_code .rstrip ()
2048+
2049+ original_code = """if 2<3:
2050+ a=4
2051+ else:
2052+ a=5
2053+ print("Hello world")
2054+ def some_fn():
2055+ print("did noting")
2056+ class NewClass:
2057+ def __init__(self, name):
2058+ self.name = name
2059+ def __call__(self, value):
2060+ return "I am still old"
2061+ def new_function2(value):
2062+ return cst.ensure_type(value, str)
2063+ """
2064+ optimized_code = """import numpy as np
2065+ if 1<2:
2066+ a=2
2067+ else:
2068+ a=3
2069+ a = 6
2070+ def some_fn():
2071+ a=np.zeros(10)
2072+ print("did something")
2073+ class NewClass:
2074+ def __init__(self, name):
2075+ self.name = name
2076+ def __call__(self, value):
2077+ return "I am still old"
2078+ def new_function2(value):
2079+ return cst.ensure_type(value, str)
2080+ print("Hello world")
2081+ """
2082+ expected_code = """import numpy as np
2083+ print("Hello world")
2084+
2085+ if 2<3:
2086+ a=4
2087+ else:
2088+ a=5
2089+ print("Hello world")
2090+ def some_fn():
2091+ a=np.zeros(10)
2092+ print("did something")
2093+ class NewClass:
2094+ def __init__(self, name):
2095+ self.name = name
2096+ def __call__(self, value):
2097+ return "I am still old"
2098+ def new_function2(value):
2099+ return cst.ensure_type(value, str)
2100+ def __init__(self, name):
2101+ self.name = name
2102+ def __call__(self, value):
2103+ return "I am still old"
2104+ def new_function2(value):
2105+ return cst.ensure_type(value, str)
2106+
2107+ a = 6
20202108"""
20212109 code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
20222110 code_path .write_text (original_code , encoding = "utf-8" )
0 commit comments