diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index f35bcfd5a..52cb80a41 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node -def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]: +def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) collector = GlobalStatementCollector() module.visit(collector) - return collector.global_statements + return module, collector.global_statements def find_last_import_line(target_code: str) -> int: @@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - non_assignment_global_statements = extract_global_statements(src_module_code) + src_module, new_added_global_statements = extract_global_statements(src_module_code) + dst_module, existing_global_statements = extract_global_statements(dst_module_code) - # Find the last import line in target - last_import_line = find_last_import_line(dst_module_code) - - # Parse the target code - target_module = cst.parse_module(dst_module_code) - - # Create transformer to insert non_assignment_global_statements - transformer = ImportInserter(non_assignment_global_statements, last_import_line) - # - # # Apply transformation - modified_module = target_module.visit(transformer) - dst_module_code = modified_module.code - - # Parse the code - original_module = cst.parse_module(dst_module_code) - new_module = cst.parse_module(src_module_code) + unique_global_statements = [] + for stmt in new_added_global_statements: + if any( + stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in 
existing_global_statements + ): + continue + unique_global_statements.append(stmt) + + mod_dst_code = dst_module_code + # Insert unique global statements if any + if unique_global_statements: + last_import_line = find_last_import_line(dst_module_code) + # Reuse already-parsed dst_module + transformer = ImportInserter(unique_global_statements, last_import_line) + # Use visit inplace, don't parse again + modified_module = dst_module.visit(transformer) + mod_dst_code = modified_module.code + # Parse the code after insertion + original_module = cst.parse_module(mod_dst_code) + else: + # No new statements to insert, reuse already-parsed dst_module + original_module = dst_module + # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file new_collector = GlobalAssignmentCollector() - new_module.visit(new_collector) + src_module.visit(new_collector) + # Only create transformer if there are assignments to insert/transform + if not new_collector.assignments: # nothing to transform + return mod_dst_code - # Transform the original file + # Transform the original destination module transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) transformed_module = original_module.visit(transformer) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 740e578b6..e05c70922 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -412,11 +412,17 @@ def replace_function_definitions_in_module( module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, + should_add_global_assignments: bool = True, # noqa: FBT001, FBT002 ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) + new_code: str = replace_functions_and_add_imports( - 
add_global_assignments(code_to_apply, source_code), + # adding the global assignments before replacing the code, not after + # because of an "edge case" where the optimized code introduced a new import and a global assignment using that import + # and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet) + # this was added at https://github.com/codeflash-ai/codeflash/pull/448 + add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code, function_names, code_to_apply, module_abspath, diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index cf57af031..78ad56ddc 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -537,6 +537,7 @@ def revert_unused_helper_functions( module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, + should_add_global_assignments=False, # since we revert helper functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. 
) if reverted_code: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index c523dcbce..1ce83c639 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -820,7 +820,10 @@ def reformat_code_and_helpers( return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( - self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str + self, + code_context: CodeOptimizationContext, + optimized_code: CodeStringsMarkdown, + original_helper_code: dict[Path, str], ) -> bool: did_update = False read_writable_functions_by_file_path = defaultdict(set) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 405896087..f7bfaace3 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1707,7 +1707,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1783,7 +1782,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") print("Hello world") def some_fn(): a=np.zeros(10) @@ -1862,7 +1860,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=3 print("Hello world") def some_fn(): @@ -1940,7 +1937,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=2 print("Hello world") def some_fn(): @@ -2019,7 +2015,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=3 print("Hello world") def some_fn(): @@ -2106,7 +2101,6 @@ def new_function2(value): a = 6 -print("Hello world") if 2<3: a=4 else: @@ -3453,3 +3447,157 @@ def hydrate_input_text_actions_with_field_names( main_file.unlink(missing_ok=True) assert new_code == expected + +def test_duplicate_global_assignments_when_reverting_helpers(): + root_dir = 
Path(__file__).parent.parent.resolve() + main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve() + + original_code = '''"""Chunking objects not specific to a particular chunking strategy.""" +from __future__ import annotations +import collections +import copy +from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast +import regex +from typing_extensions import Self, TypeAlias +from unstructured.utils import lazyproperty +from unstructured.documents.elements import Element +# ================================================================================================ +# MODEL +# ================================================================================================ +CHUNK_MAX_CHARS_DEFAULT: int = 500 +# ================================================================================================ +# PRE-CHUNKER +# ================================================================================================ +class PreChunker: + """Gathers sequential elements into pre-chunks as length constraints allow. + The pre-chunker's responsibilities are: + - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on + either side of those boundaries into different sections. In this case, the primary indicator + of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a + semantic boundary when `multipage_sections` is `False`. + - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit + into sections as big as possible without exceeding the chunk window size. + - **Minimize chunks that must be split mid-text.** Precompute the text length of each section + and only produce a section that exceeds the chunk window size when there is a single element + with text longer than that window. + A Table element is placed into a section by itself. CheckBox elements are dropped. 
+ The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates + a new "section", hence the "by-title" designation. + """ + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # -- all detectors need to be called to update state and avoid double counting + # -- boundaries that happen to coincide, like Table and new section on same element. + # -- Using `any()` would short-circuit on first True. + semantic_boundaries = [pred(element) for pred in self._boundary_predicates] + return any(semantic_boundaries) +''' + main_file.write_text(original_code, encoding="utf-8") + optim_code = f'''```python:{main_file.relative_to(root_dir)} +# ================================================================================================ +# PRE-CHUNKER +# ================================================================================================ +from __future__ import annotations +from typing import Iterable +from unstructured.documents.elements import Element +from unstructured.utils import lazyproperty +class PreChunker: + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # Use generator expression for lower memory usage and 
avoid building intermediate list + for pred in self._boundary_predicates: + if pred(element): + return True + return False +``` +''' + + func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file) + test_config = TestConfig( + tests_root=root_dir / "tests/pytest", + tests_project_rootdir=root_dir, + project_root_path=root_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code + ) + + + new_code = main_file.read_text(encoding="utf-8") + main_file.unlink(missing_ok=True) + + expected = '''"""Chunking objects not specific to a particular chunking strategy.""" +from __future__ import annotations +import collections +import copy +from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast +import regex +from typing_extensions import Self, TypeAlias +from unstructured.utils import lazyproperty +from unstructured.documents.elements import Element +# ================================================================================================ +# MODEL +# ================================================================================================ +CHUNK_MAX_CHARS_DEFAULT: int = 500 +# 
================================================================================================ +# PRE-CHUNKER +# ================================================================================================ +class PreChunker: + """Gathers sequential elements into pre-chunks as length constraints allow. + The pre-chunker's responsibilities are: + - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on + either side of those boundaries into different sections. In this case, the primary indicator + of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a + semantic boundary when `multipage_sections` is `False`. + - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit + into sections as big as possible without exceeding the chunk window size. + - **Minimize chunks that must be split mid-text.** Precompute the text length of each section + and only produce a section that exceeds the chunk window size when there is a single element + with text longer than that window. + A Table element is placed into a section by itself. CheckBox elements are dropped. + The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates + a new "section", hence the "by-title" designation. 
+ """ + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # Use generator expression for lower memory usage and avoid building intermediate list + for pred in self._boundary_predicates: + if pred(element): + return True + return False +''' + assert new_code == expected \ No newline at end of file