Skip to content

Commit b7d01a8

Browse files
committed
more tests
1 parent 0cbd204 commit b7d01a8

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

tests/test_code_replacement.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,3 +1892,156 @@ def new_function2(value):
18921892
new_code = code_path.read_text(encoding="utf-8")
18931893
code_path.unlink(missing_ok=True)
18941894
assert new_code.rstrip() == expected_code.rstrip()
1895+
1896+
original_code = """a=1
1897+
print("Hello world")
1898+
def some_fn():
1899+
print("did noting")
1900+
class NewClass:
1901+
def __init__(self, name):
1902+
self.name = name
1903+
def __call__(self, value):
1904+
return "I am still old"
1905+
def new_function2(value):
1906+
return cst.ensure_type(value, str)
1907+
"""
1908+
optimized_code = """a=2
1909+
import numpy as np
1910+
def some_fn():
1911+
a=np.zeros(10)
1912+
print("did something")
1913+
class NewClass:
1914+
def __init__(self, name):
1915+
self.name = name
1916+
def __call__(self, value):
1917+
return "I am still old"
1918+
def new_function2(value):
1919+
return cst.ensure_type(value, str)
1920+
print("Hello world")
1921+
"""
1922+
expected_code = """import numpy as np
1923+
print("Hello world")
1924+
1925+
a=2
1926+
print("Hello world")
1927+
def some_fn():
1928+
a=np.zeros(10)
1929+
print("did something")
1930+
class NewClass:
1931+
def __init__(self, name):
1932+
self.name = name
1933+
def __call__(self, value):
1934+
return "I am still old"
1935+
def new_function2(value):
1936+
return cst.ensure_type(value, str)
1937+
def __init__(self, name):
1938+
self.name = name
1939+
def __call__(self, value):
1940+
return "I am still old"
1941+
def new_function2(value):
1942+
return cst.ensure_type(value, str)
1943+
"""
1944+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
1945+
code_path.write_text(original_code, encoding="utf-8")
1946+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
1947+
project_root_path = (Path(__file__).parent / "..").resolve()
1948+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
1949+
test_config = TestConfig(
1950+
tests_root=tests_root,
1951+
tests_project_rootdir=project_root_path,
1952+
project_root_path=project_root_path,
1953+
test_framework="pytest",
1954+
pytest_cmd="pytest",
1955+
)
1956+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
1957+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
1958+
original_helper_code: dict[Path, str] = {}
1959+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
1960+
for helper_function_path in helper_function_paths:
1961+
with helper_function_path.open(encoding="utf8") as f:
1962+
helper_code = f.read()
1963+
original_helper_code[helper_function_path] = helper_code
1964+
func_optimizer.args = Args()
1965+
func_optimizer.replace_function_and_helpers_with_optimized_code(
1966+
code_context=code_context, optimized_code=optimized_code
1967+
)
1968+
new_code = code_path.read_text(encoding="utf-8")
1969+
code_path.unlink(missing_ok=True)
1970+
assert new_code.rstrip() == expected_code.rstrip()
1971+
1972+
original_code = """a=1
1973+
print("Hello world")
1974+
def some_fn():
1975+
print("did noting")
1976+
class NewClass:
1977+
def __init__(self, name):
1978+
self.name = name
1979+
def __call__(self, value):
1980+
return "I am still old"
1981+
def new_function2(value):
1982+
return cst.ensure_type(value, str)
1983+
"""
1984+
optimized_code = """import numpy as np
1985+
a=2
1986+
def some_fn():
1987+
a=np.zeros(10)
1988+
print("did something")
1989+
class NewClass:
1990+
def __init__(self, name):
1991+
self.name = name
1992+
def __call__(self, value):
1993+
return "I am still old"
1994+
def new_function2(value):
1995+
return cst.ensure_type(value, str)
1996+
a=3
1997+
print("Hello world")
1998+
"""
1999+
expected_code = """import numpy as np
2000+
print("Hello world")
2001+
2002+
a=3
2003+
print("Hello world")
2004+
def some_fn():
2005+
a=np.zeros(10)
2006+
print("did something")
2007+
class NewClass:
2008+
def __init__(self, name):
2009+
self.name = name
2010+
def __call__(self, value):
2011+
return "I am still old"
2012+
def new_function2(value):
2013+
return cst.ensure_type(value, str)
2014+
def __init__(self, name):
2015+
self.name = name
2016+
def __call__(self, value):
2017+
return "I am still old"
2018+
def new_function2(value):
2019+
return cst.ensure_type(value, str)
2020+
"""
2021+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
2022+
code_path.write_text(original_code, encoding="utf-8")
2023+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
2024+
project_root_path = (Path(__file__).parent / "..").resolve()
2025+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
2026+
test_config = TestConfig(
2027+
tests_root=tests_root,
2028+
tests_project_rootdir=project_root_path,
2029+
project_root_path=project_root_path,
2030+
test_framework="pytest",
2031+
pytest_cmd="pytest",
2032+
)
2033+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
2034+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
2035+
original_helper_code: dict[Path, str] = {}
2036+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
2037+
for helper_function_path in helper_function_paths:
2038+
with helper_function_path.open(encoding="utf8") as f:
2039+
helper_code = f.read()
2040+
original_helper_code[helper_function_path] = helper_code
2041+
func_optimizer.args = Args()
2042+
func_optimizer.replace_function_and_helpers_with_optimized_code(
2043+
code_context=code_context, optimized_code=optimized_code
2044+
)
2045+
new_code = code_path.read_text(encoding="utf-8")
2046+
code_path.unlink(missing_ok=True)
2047+
assert new_code.rstrip() == expected_code.rstrip()

0 commit comments

Comments
 (0)