From e7de51a2dfd4906f0e30f29a673157d5c9e242ac Mon Sep 17 00:00:00 2001 From: ali Date: Sat, 23 Aug 2025 18:34:37 +0300 Subject: [PATCH 1/6] prevent duplicate global assignments when reverting helpers --- codeflash/code_utils/code_replacer.py | 8 +- .../context/unused_definition_remover.py | 1 + codeflash/optimization/function_optimizer.py | 5 +- tests/test_code_replacement.py | 250 ++++++++++++++++++ 4 files changed, 262 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 740e578b6..712df8c30 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -412,11 +412,17 @@ def replace_function_definitions_in_module( module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, + global_assignments_added_before: bool = False, # noqa: FBT001, FBT002 ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) + new_code: str = replace_functions_and_add_imports( - add_global_assignments(code_to_apply, source_code), + # adding the global assignments before replacing the code, not after + # becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import + # and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet) + # this was added at https://github.com/codeflash-ai/codeflash/pull/448 + add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code, function_names, code_to_apply, module_abspath, diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index cf57af031..749757315 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -537,6 +537,7 @@ def revert_unused_helper_functions( module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, + global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. ) if reverted_code: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ba1b79492..07c596412 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -820,7 +820,10 @@ def reformat_code_and_helpers( return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( - self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str + self, + code_context: CodeOptimizationContext, + optimized_code: CodeStringsMarkdown, + original_helper_code: dict[Path, str], ) -> bool: did_update = False read_writable_functions_by_file_path = defaultdict(set) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 26e6e915b..fa0ad5adb 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -3228,3 +3228,253 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import as assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import + + +def test_test(): + root_dir = Path(__file__).parent.parent.resolve() + main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve() + + original_code = '''"""Chunking objects not specific to a particular chunking strategy.""" + +from __future__ import annotations + +import collections +import copy +from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast + +import regex +from typing_extensions import Self, TypeAlias + +from unstructured.common.html_table import HtmlCell, HtmlRow, HtmlTable +from unstructured.documents.elements import ( + CompositeElement, + ConsolidationStrategy, + Element, + ElementMetadata, + Table, + TableChunk, + Title, +) +from unstructured.utils import lazyproperty + +# ================================================================================================ +# MODEL +# ================================================================================================ + +CHUNK_MAX_CHARS_DEFAULT: int = 500 +"""Hard-max chunk-length when no explicit value specified in `max_characters` argument. + +Provided for reference only, for example so the ingest CLI can advertise the default value in its +UI. External chunking-related functions (e.g. in ingest or decorators) should use +`max_characters: int | None = None` and not apply this default themselves. Only +`ChunkingOptions.max_characters` should apply a default value. +""" + +CHUNK_MULTI_PAGE_DEFAULT: bool = True +"""When False, respect page-boundaries (no two elements from different page in same chunk). + +Only operative for "by_title" chunking strategy. +""" + +BoundaryPredicate: TypeAlias = Callable[[Element], bool] +"""Detects when element represents crossing a semantic boundary like section or page.""" + +TextAndHtml: TypeAlias = tuple[str, str] + +# ================================================================================================ +# PRE-CHUNKER +# ================================================================================================ + + +class PreChunker: + """Gathers sequential elements into pre-chunks as length constraints allow. + + The pre-chunker's responsibilities are: + + - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on + either side of those boundaries into different sections. In this case, the primary indicator + of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a + semantic boundary when `multipage_sections` is `False`. + + - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit + into sections as big as possible without exceeding the chunk window size. + + - **Minimize chunks that must be split mid-text.** Precompute the text length of each section + and only produce a section that exceeds the chunk window size when there is a single element + with text longer than that window. + + A Table element is placed into a section by itself. CheckBox elements are dropped. + + The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates + a new "section", hence the "by-title" designation. + """ + + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # -- all detectors need to be called to update state and avoid double counting + # -- boundaries that happen to coincide, like Table and new section on same element. + # -- Using `any()` would short-circuit on first True. + semantic_boundaries = [pred(element) for pred in self._boundary_predicates] + return any(semantic_boundaries) + +''' + main_file.write_text(original_code, encoding="utf-8") + optim_code = f'''```python:{main_file.relative_to(root_dir)} +# ================================================================================================ +# PRE-CHUNKER +# ================================================================================================ + +from __future__ import annotations +from typing import Iterable +from unstructured.documents.elements import Element +from unstructured.utils import lazyproperty + +class PreChunker: + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # Use generator expression for lower memory usage and avoid building intermediate list + for pred in self._boundary_predicates: + if pred(element): + return True + return False +``` +''' + + func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file) + test_config = TestConfig( + tests_root=root_dir / "tests/pytest", + tests_project_rootdir=root_dir, + project_root_path=root_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code + ) + + + new_code = main_file.read_text(encoding="utf-8") + main_file.unlink(missing_ok=True) + + expected = '''"""Chunking objects not specific to a particular chunking strategy.""" + +from __future__ import annotations + +import collections +import copy +from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast + +import regex +from typing_extensions import Self, TypeAlias + +from unstructured.common.html_table import HtmlCell, HtmlRow, HtmlTable +from unstructured.documents.elements import ( + CompositeElement, + ConsolidationStrategy, + Element, + ElementMetadata, + Table, + TableChunk, + Title, +) +from unstructured.utils import lazyproperty + +# ================================================================================================ +# MODEL +# ================================================================================================ + +CHUNK_MAX_CHARS_DEFAULT: int = 500 +"""Hard-max chunk-length when no explicit value specified in `max_characters` argument. + +Provided for reference only, for example so the ingest CLI can advertise the default value in its +UI. External chunking-related functions (e.g. in ingest or decorators) should use +`max_characters: int | None = None` and not apply this default themselves. Only +`ChunkingOptions.max_characters` should apply a default value. +""" + +CHUNK_MULTI_PAGE_DEFAULT: bool = True +"""When False, respect page-boundaries (no two elements from different page in same chunk). + +Only operative for "by_title" chunking strategy. +""" + +BoundaryPredicate: TypeAlias = Callable[[Element], bool] +"""Detects when element represents crossing a semantic boundary like section or page.""" + +TextAndHtml: TypeAlias = tuple[str, str] + +# ================================================================================================ +# PRE-CHUNKER +# ================================================================================================ + + +class PreChunker: + """Gathers sequential elements into pre-chunks as length constraints allow. + + The pre-chunker's responsibilities are: + + - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on + either side of those boundaries into different sections. In this case, the primary indicator + of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a + semantic boundary when `multipage_sections` is `False`. + + - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit + into sections as big as possible without exceeding the chunk window size. + + - **Minimize chunks that must be split mid-text.** Precompute the text length of each section + and only produce a section that exceeds the chunk window size when there is a single element + with text longer than that window. + + A Table element is placed into a section by itself. CheckBox elements are dropped. + + The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates + a new "section", hence the "by-title" designation. + """ + def __init__(self, elements: Iterable[Element], opts: ChunkingOptions): + self._elements = elements + self._opts = opts + + @lazyproperty + def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]: + """The semantic-boundary detectors to be applied to break pre-chunks.""" + return self._opts.boundary_predicates + + def _is_in_new_semantic_unit(self, element: Element) -> bool: + """True when `element` begins a new semantic unit such as a section or page.""" + # Use generator expression for lower memory usage and avoid building intermediate list + for pred in self._boundary_predicates: + if pred(element): + return True + return False + +''' + assert new_code == expected From 117e682835215d3dbebf51662b2731493464a087 Mon Sep 17 00:00:00 2001 From: ali Date: Sat, 23 Aug 2025 20:17:13 +0300 Subject: [PATCH 2/6] test: simplify --- tests/test_code_replacement.py | 62 ++-------------------------------- 1 file changed, 3 insertions(+), 59 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index fa0ad5adb..897c4a505 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -3230,7 +3230,7 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import -def test_test(): +def test_duplicate_global_assignments_when_reverting_helpers(): root_dir = Path(__file__).parent.parent.resolve() main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve() @@ -3244,42 +3244,14 @@ def test_test(): import regex from typing_extensions import Self, TypeAlias - -from unstructured.common.html_table import HtmlCell, HtmlRow, HtmlTable -from unstructured.documents.elements import ( - CompositeElement, - ConsolidationStrategy, - Element, - ElementMetadata, - Table, - TableChunk, - Title, -) from unstructured.utils import lazyproperty +from unstructured.documents.elements import Element # ================================================================================================ # MODEL # ================================================================================================ CHUNK_MAX_CHARS_DEFAULT: int = 500 -"""Hard-max chunk-length when no explicit value specified in `max_characters` argument. - -Provided for reference only, for example so the ingest CLI can advertise the default value in its -UI. External chunking-related functions (e.g. in ingest or decorators) should use -`max_characters: int | None = None` and not apply this default themselves. Only -`ChunkingOptions.max_characters` should apply a default value. -""" - -CHUNK_MULTI_PAGE_DEFAULT: bool = True -"""When False, respect page-boundaries (no two elements from different page in same chunk). - -Only operative for "by_title" chunking strategy. -""" - -BoundaryPredicate: TypeAlias = Callable[[Element], bool] -"""Detects when element represents crossing a semantic boundary like section or page.""" - -TextAndHtml: TypeAlias = tuple[str, str] # ================================================================================================ # PRE-CHUNKER @@ -3395,42 +3367,14 @@ def _is_in_new_semantic_unit(self, element: Element) -> bool: import regex from typing_extensions import Self, TypeAlias - -from unstructured.common.html_table import HtmlCell, HtmlRow, HtmlTable -from unstructured.documents.elements import ( - CompositeElement, - ConsolidationStrategy, - Element, - ElementMetadata, - Table, - TableChunk, - Title, -) from unstructured.utils import lazyproperty +from unstructured.documents.elements import Element # ================================================================================================ # MODEL # ================================================================================================ CHUNK_MAX_CHARS_DEFAULT: int = 500 -"""Hard-max chunk-length when no explicit value specified in `max_characters` argument. - -Provided for reference only, for example so the ingest CLI can advertise the default value in its -UI. External chunking-related functions (e.g. in ingest or decorators) should use -`max_characters: int | None = None` and not apply this default themselves. Only -`ChunkingOptions.max_characters` should apply a default value. -""" - -CHUNK_MULTI_PAGE_DEFAULT: bool = True -"""When False, respect page-boundaries (no two elements from different page in same chunk). - -Only operative for "by_title" chunking strategy. -""" - -BoundaryPredicate: TypeAlias = Callable[[Element], bool] -"""Detects when element represents crossing a semantic boundary like section or page.""" - -TextAndHtml: TypeAlias = tuple[str, str] # ================================================================================================ # PRE-CHUNKER From 9c8256a122546cc85ab7234ebd2d989b7092fc9f Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 25 Aug 2025 00:44:12 +0300 Subject: [PATCH 3/6] prevent duplicates for new global statements --- codeflash/code_utils/code_extractor.py | 36 +++++++++++-------- codeflash/code_utils/code_replacer.py | 3 +- .../context/unused_definition_remover.py | 1 - tests/test_code_replacement.py | 6 ---- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0f69bed7a..a6cab79b4 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -338,20 +338,28 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - non_assignment_global_statements = extract_global_statements(src_module_code) - - # Find the last import line in target - last_import_line = find_last_import_line(dst_module_code) - - # Parse the target code - target_module = cst.parse_module(dst_module_code) - - # Create transformer to insert non_assignment_global_statements - transformer = ImportInserter(non_assignment_global_statements, last_import_line) - # - # # Apply transformation - modified_module = target_module.visit(transformer) - dst_module_code = modified_module.code + new_added_global_statements = extract_global_statements(src_module_code) + existing_global_statements = extract_global_statements(dst_module_code) + + unique_global_statements = [ + stmt + for stmt in new_added_global_statements + if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements) + ] + + if unique_global_statements: + # Find the last import line in target + last_import_line = find_last_import_line(dst_module_code) + + # Parse the target code + target_module = cst.parse_module(dst_module_code) + + # Create transformer to insert new statements + transformer = ImportInserter(unique_global_statements, last_import_line) + # + # # Apply transformation + modified_module = target_module.visit(transformer) + dst_module_code = modified_module.code # Parse the code original_module = cst.parse_module(dst_module_code) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 712df8c30..c958b5adf 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -412,7 +412,6 @@ def replace_function_definitions_in_module( module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, - global_assignments_added_before: bool = False, # noqa: FBT001, FBT002 ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) @@ -422,7 +421,7 @@ def replace_function_definitions_in_module( # becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import # and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet) # this was added at https://github.com/codeflash-ai/codeflash/pull/448 - add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code, + add_global_assignments(code_to_apply, source_code), function_names, code_to_apply, module_abspath, diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 749757315..cf57af031 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -537,7 +537,6 @@ def revert_unused_helper_functions( module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, - global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. ) if reverted_code: diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 897c4a505..5b12e3a05 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1707,7 +1707,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1783,7 +1782,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") print("Hello world") def some_fn(): a=np.zeros(10) @@ -1862,7 +1860,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=3 print("Hello world") def some_fn(): @@ -1940,7 +1937,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=2 print("Hello world") def some_fn(): @@ -2019,7 +2015,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") a=3 print("Hello world") def some_fn(): @@ -2104,7 +2099,6 @@ def new_function2(value): """ expected_code = """import numpy as np -print("Hello world") if 2<3: a=4 else: From 8107bce165cfe590c05c5052742fd8d75ca71aba Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 18:50:37 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20function?= =?UTF-8?q?=20`add=5Fglobal=5Fassignments`=20by=2018%=20in=20PR=20#683=20(?= =?UTF-8?q?`fix/duplicate-global-assignments-when-reverting-helpers`)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The optimized code achieves a **17% speedup** by eliminating redundant CST parsing operations, which are the most expensive parts of the function according to the line profiler. **Key optimizations:** 1. **Eliminate duplicate parsing**: The original code parsed `src_module_code` and `dst_module_code` multiple times. The optimized version introduces `_extract_global_statements_once()` that parses each module only once and reuses the parsed CST objects throughout the function. 2. **Reuse parsed modules**: Instead of re-parsing `dst_module_code` after modifications, the optimized version conditionally reuses the already-parsed `dst_module` when no global statements need insertion, avoiding unnecessary `cst.parse_module()` calls. 3. **Early termination**: Added an early return when `new_collector.assignments` is empty, avoiding the expensive `GlobalAssignmentTransformer` creation and visitation when there's nothing to transform. 4. **Minor optimization in uniqueness check**: Added a fast-path identity check (`stmt is existing_stmt`) before the expensive `deep_equals()` comparison, though this has minimal impact. **Performance impact by test case type:** - **Empty/minimal cases**: Show the highest gains (59-88% faster) due to early termination optimizations - **Standard cases**: Achieve consistent 20-30% improvements from reduced parsing - **Large-scale tests**: Benefit significantly (18-23% faster) as parsing overhead scales with code size The optimization is most effective for workloads with moderate to large code files where CST parsing dominates the runtime, as evidenced by the original profiler showing 70%+ of time spent in `cst.parse_module()` and `module.visit()` operations. --- codeflash/code_utils/code_extractor.py | 65 ++++++++++++++++---------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4c50d978f..5ea4c7d46 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -373,39 +373,46 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - new_added_global_statements = extract_global_statements(src_module_code) - existing_global_statements = extract_global_statements(dst_module_code) - - # make sure we don't have any staments applited multiple times in the global level. - unique_global_statements = [ - stmt - for stmt in new_added_global_statements - if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements) - ] + # Avoid repeat parses and visits + src_module, new_added_global_statements = _extract_global_statements_once(src_module_code) + dst_module, existing_global_statements = _extract_global_statements_once(dst_module_code) + + # Build a list of global statements which are not already present using more efficient membership test. + # Slightly optimized by making a set of (hashable deep identity) for comparison. + # However, since CST nodes are not hashable, continue using deep_equals but do NOT recompute for identical object references. + unique_global_statements = [] + for stmt in new_added_global_statements: + # Fast path: check by id + if any( + stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements + ): + continue + unique_global_statements.append(stmt) + mod_dst_code = dst_module_code + # Insert unique global statements if any if unique_global_statements: - # Find the last import line in target last_import_line = find_last_import_line(dst_module_code) - - # Parse the target code - target_module = cst.parse_module(dst_module_code) - - # Create transformer to insert new statements + # Reuse already-parsed dst_module transformer = ImportInserter(unique_global_statements, last_import_line) - # - # # Apply transformation - modified_module = target_module.visit(transformer) - dst_module_code = modified_module.code - - # Parse the code - original_module = cst.parse_module(dst_module_code) - new_module = cst.parse_module(src_module_code) + # Use visit inplace, don't parse again + modified_module = dst_module.visit(transformer) + mod_dst_code = modified_module.code + # Parse the code after insertion + original_module = cst.parse_module(mod_dst_code) + else: + # No new statements to insert, reuse already-parsed dst_module + original_module = dst_module + # Parse the src_module_code once only (already done above: src_module) # Collect assignments from the new file new_collector = GlobalAssignmentCollector() - new_module.visit(new_collector) + src_module.visit(new_collector) + # Only create transformer if there are assignments to insert/transform + if not new_collector.assignments: # nothing to transform + return mod_dst_code - # Transform the original file + # Transform the original destination module transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) transformed_module = original_module.visit(transformer) @@ -644,3 +651,11 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),))) return preexisting_objects + + +def _extract_global_statements_once(source_code: str): + """Extract global statements once and return both module and statements (internal)""" + module = cst.parse_module(source_code) + collector = GlobalStatementCollector() + module.visit(collector) + return module, collector.global_statements From b18d2c9775d7aa0893025d46273b1ab4272d7772 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 26 Aug 2025 03:46:18 +0300 Subject: [PATCH 5/6] better name --- codeflash/code_utils/code_replacer.py | 4 ++-- codeflash/context/unused_definition_remover.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 712df8c30..e05c70922 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -412,7 +412,7 @@ def replace_function_definitions_in_module( module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, - global_assignments_added_before: bool = False, # noqa: FBT001, FBT002 + should_add_global_assignments: bool = True, # noqa: FBT001, FBT002 ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) @@ -422,7 +422,7 @@ def replace_function_definitions_in_module( # becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import # and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet) # this was added at https://github.com/codeflash-ai/codeflash/pull/448 - add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code, + add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code, function_names, code_to_apply, module_abspath, diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 749757315..78ad56ddc 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -537,7 +537,7 @@ def revert_unused_helper_functions( module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, - global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. + should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. ) if reverted_code: From 64466626f419794267900ebb9285d1d0fd0fffc6 Mon Sep 17 00:00:00 2001 From: ali Date: Sat, 30 Aug 2025 07:23:01 +0300 Subject: [PATCH 6/6] cleanup --- codeflash/code_utils/code_extractor.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 5ea4c7d46..52cb80a41 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node -def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]: +def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]: """Extract global statements from source code.""" module = cst.parse_module(source_code) collector = GlobalStatementCollector() module.visit(collector) - return collector.global_statements + return module, collector.global_statements def find_last_import_line(target_code: str) -> int: @@ -373,16 +373,11 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: - # Avoid repeat parses and visits - src_module, new_added_global_statements = _extract_global_statements_once(src_module_code) - dst_module, existing_global_statements = _extract_global_statements_once(dst_module_code) + src_module, new_added_global_statements = extract_global_statements(src_module_code) + dst_module, existing_global_statements = extract_global_statements(dst_module_code) - # Build a list of global statements which are not already present using more efficient membership test. - # Slightly optimized by making a set of (hashable deep identity) for comparison. - # However, since CST nodes are not hashable, continue using deep_equals but do NOT recompute for identical object references. unique_global_statements = [] for stmt in new_added_global_statements: - # Fast path: check by id if any( stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements ): @@ -651,11 +646,3 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),))) return preexisting_objects - - -def _extract_global_statements_once(source_code: str): - """Extract global statements once and return both module and statements (internal)""" - module = cst.parse_module(source_code) - collector = GlobalStatementCollector() - module.visit(collector) - return module, collector.global_statements