Skip to content

Commit 0cbd204

Browse files
committed
more tests with different positions of global variables and multiple reassignments
1 parent 28596b7 commit 0cbd204

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

tests/test_code_replacement.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,159 @@ def new_function2(value):
16951695
16961696
a=2
16971697
print("Hello world")
1698+
def some_fn():
1699+
a=np.zeros(10)
1700+
print("did something")
1701+
class NewClass:
1702+
def __init__(self, name):
1703+
self.name = name
1704+
def __call__(self, value):
1705+
return "I am still old"
1706+
def new_function2(value):
1707+
return cst.ensure_type(value, str)
1708+
def __init__(self, name):
1709+
self.name = name
1710+
def __call__(self, value):
1711+
return "I am still old"
1712+
def new_function2(value):
1713+
return cst.ensure_type(value, str)
1714+
"""
1715+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
1716+
code_path.write_text(original_code, encoding="utf-8")
1717+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
1718+
project_root_path = (Path(__file__).parent / "..").resolve()
1719+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
1720+
test_config = TestConfig(
1721+
tests_root=tests_root,
1722+
tests_project_rootdir=project_root_path,
1723+
project_root_path=project_root_path,
1724+
test_framework="pytest",
1725+
pytest_cmd="pytest",
1726+
)
1727+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
1728+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
1729+
original_helper_code: dict[Path, str] = {}
1730+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
1731+
for helper_function_path in helper_function_paths:
1732+
with helper_function_path.open(encoding="utf8") as f:
1733+
helper_code = f.read()
1734+
original_helper_code[helper_function_path] = helper_code
1735+
func_optimizer.args = Args()
1736+
func_optimizer.replace_function_and_helpers_with_optimized_code(
1737+
code_context=code_context, optimized_code=optimized_code
1738+
)
1739+
new_code = code_path.read_text(encoding="utf-8")
1740+
code_path.unlink(missing_ok=True)
1741+
assert new_code.rstrip() == expected_code.rstrip()
1742+
1743+
original_code = """print("Hello world")
1744+
def some_fn():
1745+
print("did noting")
1746+
class NewClass:
1747+
def __init__(self, name):
1748+
self.name = name
1749+
def __call__(self, value):
1750+
return "I am still old"
1751+
def new_function2(value):
1752+
return cst.ensure_type(value, str)
1753+
a=1
1754+
"""
1755+
optimized_code = """a=2
1756+
import numpy as np
1757+
def some_fn():
1758+
a=np.zeros(10)
1759+
print("did something")
1760+
class NewClass:
1761+
def __init__(self, name):
1762+
self.name = name
1763+
def __call__(self, value):
1764+
return "I am still old"
1765+
def new_function2(value):
1766+
return cst.ensure_type(value, str)
1767+
print("Hello world")
1768+
"""
1769+
expected_code = """import numpy as np
1770+
print("Hello world")
1771+
1772+
print("Hello world")
1773+
def some_fn():
1774+
a=np.zeros(10)
1775+
print("did something")
1776+
class NewClass:
1777+
def __init__(self, name):
1778+
self.name = name
1779+
def __call__(self, value):
1780+
return "I am still old"
1781+
def new_function2(value):
1782+
return cst.ensure_type(value, str)
1783+
def __init__(self, name):
1784+
self.name = name
1785+
def __call__(self, value):
1786+
return "I am still old"
1787+
def new_function2(value):
1788+
return cst.ensure_type(value, str)
1789+
a=2
1790+
"""
1791+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
1792+
code_path.write_text(original_code, encoding="utf-8")
1793+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
1794+
project_root_path = (Path(__file__).parent / "..").resolve()
1795+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
1796+
test_config = TestConfig(
1797+
tests_root=tests_root,
1798+
tests_project_rootdir=project_root_path,
1799+
project_root_path=project_root_path,
1800+
test_framework="pytest",
1801+
pytest_cmd="pytest",
1802+
)
1803+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
1804+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
1805+
original_helper_code: dict[Path, str] = {}
1806+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
1807+
for helper_function_path in helper_function_paths:
1808+
with helper_function_path.open(encoding="utf8") as f:
1809+
helper_code = f.read()
1810+
original_helper_code[helper_function_path] = helper_code
1811+
func_optimizer.args = Args()
1812+
func_optimizer.replace_function_and_helpers_with_optimized_code(
1813+
code_context=code_context, optimized_code=optimized_code
1814+
)
1815+
new_code = code_path.read_text(encoding="utf-8")
1816+
code_path.unlink(missing_ok=True)
1817+
assert new_code.rstrip() == expected_code.rstrip()
1818+
1819+
original_code = """a=1
1820+
print("Hello world")
1821+
def some_fn():
1822+
print("did noting")
1823+
class NewClass:
1824+
def __init__(self, name):
1825+
self.name = name
1826+
def __call__(self, value):
1827+
return "I am still old"
1828+
def new_function2(value):
1829+
return cst.ensure_type(value, str)
1830+
"""
1831+
optimized_code = """import numpy as np
1832+
a=2
1833+
def some_fn():
1834+
a=np.zeros(10)
1835+
print("did something")
1836+
class NewClass:
1837+
def __init__(self, name):
1838+
self.name = name
1839+
def __call__(self, value):
1840+
return "I am still old"
1841+
def new_function2(value):
1842+
return cst.ensure_type(value, str)
1843+
a=3
1844+
print("Hello world")
1845+
"""
1846+
expected_code = """import numpy as np
1847+
print("Hello world")
1848+
1849+
a=3
1850+
print("Hello world")
16981851
def some_fn():
16991852
a=np.zeros(10)
17001853
print("did something")

0 commit comments

Comments
 (0)