Skip to content

Commit 28596b7

Browse files
committed
refactoring to run only for code replacement before PR/behavior instead of code context extraction
1 parent 3da7b73 commit 28596b7

File tree

3 files changed

+64
-34
lines changed

3 files changed

+64
-34
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,15 +235,7 @@ def delete___future___aliased_imports(module_code: str) -> str:
235235
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code
236236

237237

238-
def add_needed_imports_from_module(
239-
src_module_code: str,
240-
dst_module_code: str,
241-
src_path: Path,
242-
dst_path: Path,
243-
project_root: Path,
244-
helper_functions: list[FunctionSource] | None = None,
245-
helper_functions_fqn: set[str] | None = None,
246-
) -> str:
238+
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
247239
non_assignment_global_statements = extract_global_statements(src_module_code)
248240

249241
# Find the last import line in target
@@ -272,7 +264,18 @@ def add_needed_imports_from_module(
272264
transformed_module = original_module.visit(transformer)
273265

274266
dst_module_code = transformed_module.code
267+
return dst_module_code
268+
275269

270+
def add_needed_imports_from_module(
271+
src_module_code: str,
272+
dst_module_code: str,
273+
src_path: Path,
274+
dst_path: Path,
275+
project_root: Path,
276+
helper_functions: list[FunctionSource] | None = None,
277+
helper_functions_fqn: set[str] | None = None,
278+
) -> str:
276279
"""Add all needed and used source module code imports to the destination module code, and return it."""
277280
src_module_code = delete___future___aliased_imports(src_module_code)
278281
if not helper_functions_fqn:

codeflash/code_utils/code_replacer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import libcst as cst
99

1010
from codeflash.cli_cmds.console import logger
11-
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
11+
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments
1212
from codeflash.models.models import FunctionParent
1313

1414
if TYPE_CHECKING:
@@ -220,7 +220,8 @@ def replace_function_definitions_in_module(
220220
)
221221
if is_zero_diff(source_code, new_code):
222222
return False
223-
module_abspath.write_text(new_code, encoding="utf8")
223+
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
224+
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
224225
return True
225226

226227

tests/test_code_replacement.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,8 @@ def __init__(self, name):
804804
self.name = name
805805
806806
def main_method(self):
807-
return HelperClass(self.name).helper_method()"""
807+
return HelperClass(self.name).helper_method()
808+
"""
808809
file_path = Path(__file__).resolve()
809810
func_top_optimize = FunctionToOptimize(
810811
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@@ -1665,51 +1666,76 @@ def new_function2(value):
16651666
def test_global_reassignment() -> None:
16661667
original_code = """a=1
16671668
print("Hello world")
1669+
def some_fn():
1670+
print("did noting")
16681671
class NewClass:
16691672
def __init__(self, name):
16701673
self.name = name
16711674
def __call__(self, value):
16721675
return "I am still old"
16731676
def new_function2(value):
16741677
return cst.ensure_type(value, str)
1675-
"""
1676-
1677-
optim_code = """import numpy as np
1678+
"""
1679+
optimized_code = """import numpy as np
1680+
def some_fn():
1681+
a=np.zeros(10)
1682+
print("did something")
16781683
class NewClass:
16791684
def __init__(self, name):
16801685
self.name = name
16811686
def __call__(self, value):
1682-
w = np.array([1,2,3])
1683-
return "I am new"
1687+
return "I am still old"
16841688
def new_function2(value):
16851689
return cst.ensure_type(value, str)
16861690
a=2
16871691
print("Hello world")
1688-
"""
1689-
1690-
modified_code = """import numpy as np
1691-
1692+
"""
1693+
expected_code = """import numpy as np
16921694
print("Hello world")
1695+
16931696
a=2
16941697
print("Hello world")
1698+
def some_fn():
1699+
a=np.zeros(10)
1700+
print("did something")
16951701
class NewClass:
16961702
def __init__(self, name):
16971703
self.name = name
16981704
def __call__(self, value):
1699-
w = np.array([1,2,3])
1700-
return "I am new"
1705+
return "I am still old"
1706+
def new_function2(value):
1707+
return cst.ensure_type(value, str)
1708+
def __init__(self, name):
1709+
self.name = name
1710+
def __call__(self, value):
1711+
return "I am still old"
17011712
def new_function2(value):
17021713
return cst.ensure_type(value, str)
17031714
"""
1704-
1705-
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
1706-
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
1707-
new_code: str = replace_functions_and_add_imports(
1708-
source_code=original_code,
1709-
function_names=function_names,
1710-
optimized_code=optim_code,
1711-
module_abspath=Path(__file__).resolve(),
1712-
preexisting_objects=preexisting_objects,
1713-
project_root_path=Path(__file__).resolve().parent.resolve(),
1715+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
1716+
code_path.write_text(original_code, encoding="utf-8")
1717+
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
1718+
project_root_path = (Path(__file__).parent / "..").resolve()
1719+
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
1720+
test_config = TestConfig(
1721+
tests_root=tests_root,
1722+
tests_project_rootdir=project_root_path,
1723+
project_root_path=project_root_path,
1724+
test_framework="pytest",
1725+
pytest_cmd="pytest",
17141726
)
1715-
assert new_code == modified_code
1727+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
1728+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
1729+
original_helper_code: dict[Path, str] = {}
1730+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
1731+
for helper_function_path in helper_function_paths:
1732+
with helper_function_path.open(encoding="utf8") as f:
1733+
helper_code = f.read()
1734+
original_helper_code[helper_function_path] = helper_code
1735+
func_optimizer.args = Args()
1736+
func_optimizer.replace_function_and_helpers_with_optimized_code(
1737+
code_context=code_context, optimized_code=optimized_code
1738+
)
1739+
new_code = code_path.read_text(encoding="utf-8")
1740+
code_path.unlink(missing_ok=True)
1741+
assert new_code.rstrip() == expected_code.rstrip()

0 commit comments

Comments
 (0)