@@ -1695,6 +1695,159 @@ def new_function2(value):
16951695
16961696a=2
16971697print("Hello world")
1698+ def some_fn():
1699+ a=np.zeros(10)
1700+ print("did something")
1701+ class NewClass:
1702+ def __init__(self, name):
1703+ self.name = name
1704+ def __call__(self, value):
1705+ return "I am still old"
1706+ def new_function2(value):
1707+ return cst.ensure_type(value, str)
1708+ def __init__(self, name):
1709+ self.name = name
1710+ def __call__(self, value):
1711+ return "I am still old"
1712+ def new_function2(value):
1713+ return cst.ensure_type(value, str)
1714+ """
1715+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
1716+ code_path .write_text (original_code , encoding = "utf-8" )
1717+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
1718+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
1719+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
1720+ test_config = TestConfig (
1721+ tests_root = tests_root ,
1722+ tests_project_rootdir = project_root_path ,
1723+ project_root_path = project_root_path ,
1724+ test_framework = "pytest" ,
1725+ pytest_cmd = "pytest" ,
1726+ )
1727+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
1728+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
1729+ original_helper_code : dict [Path , str ] = {}
1730+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
1731+ for helper_function_path in helper_function_paths :
1732+ with helper_function_path .open (encoding = "utf8" ) as f :
1733+ helper_code = f .read ()
1734+ original_helper_code [helper_function_path ] = helper_code
1735+ func_optimizer .args = Args ()
1736+ func_optimizer .replace_function_and_helpers_with_optimized_code (
1737+ code_context = code_context , optimized_code = optimized_code
1738+ )
1739+ new_code = code_path .read_text (encoding = "utf-8" )
1740+ code_path .unlink (missing_ok = True )
1741+ assert new_code .rstrip () == expected_code .rstrip ()
1742+
1743+ original_code = """print("Hello world")
1744+ def some_fn():
1745+ print("did noting")
1746+ class NewClass:
1747+ def __init__(self, name):
1748+ self.name = name
1749+ def __call__(self, value):
1750+ return "I am still old"
1751+ def new_function2(value):
1752+ return cst.ensure_type(value, str)
1753+ a=1
1754+ """
1755+ optimized_code = """a=2
1756+ import numpy as np
1757+ def some_fn():
1758+ a=np.zeros(10)
1759+ print("did something")
1760+ class NewClass:
1761+ def __init__(self, name):
1762+ self.name = name
1763+ def __call__(self, value):
1764+ return "I am still old"
1765+ def new_function2(value):
1766+ return cst.ensure_type(value, str)
1767+ print("Hello world")
1768+ """
1769+ expected_code = """import numpy as np
1770+ print("Hello world")
1771+
1772+ print("Hello world")
1773+ def some_fn():
1774+ a=np.zeros(10)
1775+ print("did something")
1776+ class NewClass:
1777+ def __init__(self, name):
1778+ self.name = name
1779+ def __call__(self, value):
1780+ return "I am still old"
1781+ def new_function2(value):
1782+ return cst.ensure_type(value, str)
1783+ def __init__(self, name):
1784+ self.name = name
1785+ def __call__(self, value):
1786+ return "I am still old"
1787+ def new_function2(value):
1788+ return cst.ensure_type(value, str)
1789+ a=2
1790+ """
1791+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/global_var_original.py" ).resolve ()
1792+ code_path .write_text (original_code , encoding = "utf-8" )
1793+ tests_root = Path ("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/" )
1794+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
1795+ func = FunctionToOptimize (function_name = "some_fn" , parents = [], file_path = code_path )
1796+ test_config = TestConfig (
1797+ tests_root = tests_root ,
1798+ tests_project_rootdir = project_root_path ,
1799+ project_root_path = project_root_path ,
1800+ test_framework = "pytest" ,
1801+ pytest_cmd = "pytest" ,
1802+ )
1803+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
1804+ code_context : CodeOptimizationContext = func_optimizer .get_code_optimization_context ().unwrap ()
1805+ original_helper_code : dict [Path , str ] = {}
1806+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
1807+ for helper_function_path in helper_function_paths :
1808+ with helper_function_path .open (encoding = "utf8" ) as f :
1809+ helper_code = f .read ()
1810+ original_helper_code [helper_function_path ] = helper_code
1811+ func_optimizer .args = Args ()
1812+ func_optimizer .replace_function_and_helpers_with_optimized_code (
1813+ code_context = code_context , optimized_code = optimized_code
1814+ )
1815+ new_code = code_path .read_text (encoding = "utf-8" )
1816+ code_path .unlink (missing_ok = True )
1817+ assert new_code .rstrip () == expected_code .rstrip ()
1818+
1819+ original_code = """a=1
1820+ print("Hello world")
1821+ def some_fn():
1822+ print("did noting")
1823+ class NewClass:
1824+ def __init__(self, name):
1825+ self.name = name
1826+ def __call__(self, value):
1827+ return "I am still old"
1828+ def new_function2(value):
1829+ return cst.ensure_type(value, str)
1830+ """
1831+ optimized_code = """import numpy as np
1832+ a=2
1833+ def some_fn():
1834+ a=np.zeros(10)
1835+ print("did something")
1836+ class NewClass:
1837+ def __init__(self, name):
1838+ self.name = name
1839+ def __call__(self, value):
1840+ return "I am still old"
1841+ def new_function2(value):
1842+ return cst.ensure_type(value, str)
1843+ a=3
1844+ print("Hello world")
1845+ """
1846+ expected_code = """import numpy as np
1847+ print("Hello world")
1848+
1849+ a=3
1850+ print("Hello world")
16981851def some_fn():
16991852 a=np.zeros(10)
17001853 print("did something")
0 commit comments