@@ -1892,3 +1892,156 @@ def new_function2(value):
18921892 new_code = code_path .read_text (encoding = "utf-8" )
18931893 code_path .unlink (missing_ok = True )
18941894 assert new_code .rstrip () == expected_code .rstrip ()
1895+
1896+ original_code = """a=1
1897+ print("Hello world")
1898+ def some_fn():
1899+ print("did noting")
1900+ class NewClass:
1901+ def __init__(self, name):
1902+ self.name = name
1903+ def __call__(self, value):
1904+ return "I am still old"
1905+ def new_function2(value):
1906+ return cst.ensure_type(value, str)
1907+ """
1908+ optimized_code = """a=2
1909+ import numpy as np
1910+ def some_fn():
1911+ a=np.zeros(10)
1912+ print("did something")
1913+ class NewClass:
1914+ def __init__(self, name):
1915+ self.name = name
1916+ def __call__(self, value):
1917+ return "I am still old"
1918+ def new_function2(value):
1919+ return cst.ensure_type(value, str)
1920+ print("Hello world")
1921+ """
1922+ expected_code = """import numpy as np
1923+ print("Hello world")
1924+
1925+ a=2
1926+ print("Hello world")
1927+ def some_fn():
1928+ a=np.zeros(10)
1929+ print("did something")
1930+ class NewClass:
1931+ def __init__(self, name):
1932+ self.name = name
1933+ def __call__(self, value):
1934+ return "I am still old"
1935+ def new_function2(value):
1936+ return cst.ensure_type(value, str)
1937+ def __init__(self, name):
1938+ self.name = name
1939+ def __call__(self, value):
1940+ return "I am still old"
1941+ def new_function2(value):
1942+ return cst.ensure_type(value, str)
1943+ """
1944+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
1945+ code_path .write_text (original_code , encoding = "utf-8" )
1946+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
1947+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
1948+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
1949+ test_config = TestConfig (
1950+ tests_root = tests_root ,
1951+ tests_project_rootdir = project_root_path ,
1952+ project_root_path = project_root_path ,
1953+ test_framework = "pytest" ,
1954+ pytest_cmd = "pytest" ,
1955+ )
1956+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
1957+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
1958+ original_helper_code : dict [Path , str ] = {}
1959+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
1960+ for helper_function_path in helper_function_paths :
1961+ with helper_function_path .open (encoding = "utf8" ) as f :
1962+ helper_code = f .read ()
1963+ original_helper_code [helper_function_path ] = helper_code
1964+ func_optimizer .args = Args ()
1965+ func_optimizer .replace_function_and_helpers_with_optimized_code (
1966+ code_context = code_context , optimized_code = optimized_code
1967+ )
1968+ new_code = code_path .read_text (encoding = "utf-8" )
1969+ code_path .unlink (missing_ok = True )
1970+ assert new_code .rstrip () == expected_code .rstrip ()
1971+
1972+ original_code = """a=1
1973+ print("Hello world")
1974+ def some_fn():
1975+ print("did noting")
1976+ class NewClass:
1977+ def __init__(self, name):
1978+ self.name = name
1979+ def __call__(self, value):
1980+ return "I am still old"
1981+ def new_function2(value):
1982+ return cst.ensure_type(value, str)
1983+ """
1984+ optimized_code = """import numpy as np
1985+ a=2
1986+ def some_fn():
1987+ a=np.zeros(10)
1988+ print("did something")
1989+ class NewClass:
1990+ def __init__(self, name):
1991+ self.name = name
1992+ def __call__(self, value):
1993+ return "I am still old"
1994+ def new_function2(value):
1995+ return cst.ensure_type(value, str)
1996+ a=3
1997+ print("Hello world")
1998+ """
1999+ expected_code = """import numpy as np
2000+ print("Hello world")
2001+
2002+ a=3
2003+ print("Hello world")
2004+ def some_fn():
2005+ a=np.zeros(10)
2006+ print("did something")
2007+ class NewClass:
2008+ def __init__(self, name):
2009+ self.name = name
2010+ def __call__(self, value):
2011+ return "I am still old"
2012+ def new_function2(value):
2013+ return cst.ensure_type(value, str)
2014+ def __init__(self, name):
2015+ self.name = name
2016+ def __call__(self, value):
2017+ return "I am still old"
2018+ def new_function2(value):
2019+ return cst.ensure_type(value, str)
2020+ """
2021+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
2022+ code_path .write_text (original_code , encoding = "utf-8" )
2023+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
2024+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
2025+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
2026+ test_config = TestConfig (
2027+ tests_root = tests_root ,
2028+ tests_project_rootdir = project_root_path ,
2029+ project_root_path = project_root_path ,
2030+ test_framework = "pytest" ,
2031+ pytest_cmd = "pytest" ,
2032+ )
2033+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
2034+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
2035+ original_helper_code : dict [Path , str ] = {}
2036+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
2037+ for helper_function_path in helper_function_paths :
2038+ with helper_function_path .open (encoding = "utf8" ) as f :
2039+ helper_code = f .read ()
2040+ original_helper_code [helper_function_path ] = helper_code
2041+ func_optimizer .args = Args ()
2042+ func_optimizer .replace_function_and_helpers_with_optimized_code (
2043+ code_context = code_context , optimized_code = optimized_code
2044+ )
2045+ new_code = code_path .read_text (encoding = "utf-8" )
2046+ code_path .unlink (missing_ok = True )
2047+ assert new_code .rstrip () == expected_code .rstrip ()
0 commit comments