53 changes: 32 additions & 21 deletions codeflash/code_utils/code_extractor.py
@@ -335,12 +335,12 @@ def leave_Module(original_node: cst.Module, updated_node: cst.Module) -> c
return updated_node


def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
"""Extract global statements from source code."""
module = cst.parse_module(source_code)
collector = GlobalStatementCollector()
module.visit(collector)
return collector.global_statements
return module, collector.global_statements
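
For reference, a minimal sketch of the parse-once pattern this change enables; the collector below is a simplified, hypothetical stand-in for GlobalStatementCollector:

```python
import libcst as cst


class StatementCollector(cst.CSTVisitor):
    """Hypothetical stand-in for GlobalStatementCollector."""

    def __init__(self) -> None:
        self.statements: list[cst.SimpleStatementLine] = []

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> bool:
        # Record the statement; the real collector additionally restricts
        # itself to module-level statements and applies its own filters.
        self.statements.append(node)
        return False  # no need to descend into the statement's children


module = cst.parse_module('print("Hello world")\nx = 1\n')
collector = StatementCollector()
module.visit(collector)
# Returning the parsed module alongside the statements lets callers such as
# add_global_assignments reuse it instead of parsing the same source twice.
assert len(collector.statements) == 2
```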


def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:


def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
non_assignment_global_statements = extract_global_statements(src_module_code)
src_module, new_added_global_statements = extract_global_statements(src_module_code)
dst_module, existing_global_statements = extract_global_statements(dst_module_code)

# Find the last import line in target
last_import_line = find_last_import_line(dst_module_code)

# Parse the target code
target_module = cst.parse_module(dst_module_code)

# Create transformer to insert non_assignment_global_statements
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
#
# # Apply transformation
modified_module = target_module.visit(transformer)
dst_module_code = modified_module.code

# Parse the code
original_module = cst.parse_module(dst_module_code)
new_module = cst.parse_module(src_module_code)
unique_global_statements = []
for stmt in new_added_global_statements:
if any(
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
):
continue
unique_global_statements.append(stmt)

mod_dst_code = dst_module_code
# Insert unique global statements if any
if unique_global_statements:
last_import_line = find_last_import_line(dst_module_code)
# Reuse already-parsed dst_module
transformer = ImportInserter(unique_global_statements, last_import_line)
# Visit the existing tree in place instead of parsing again
modified_module = dst_module.visit(transformer)
mod_dst_code = modified_module.code
# Parse the code after insertion
original_module = cst.parse_module(mod_dst_code)
else:
# No new statements to insert, reuse already-parsed dst_module
original_module = dst_module

# src_module_code was already parsed above (src_module); reuse it here
# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
new_module.visit(new_collector)
src_module.visit(new_collector)
# Only create transformer if there are assignments to insert/transform
if not new_collector.assignments: # nothing to transform
return mod_dst_code

# Transform the original file
# Transform the original destination module
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)
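
The dedup above hinges on libcst's structural comparison rather than object identity; a quick sketch of the behavior it assumes:

```python
import libcst as cst

# Two separately parsed copies of the same statement are distinct objects,
# but deep_equals compares CST structure, so they count as duplicates.
a = cst.parse_statement('print("Hello world")')
b = cst.parse_statement('print("Hello world")')
assert a is not b        # identity alone would miss the duplicate
assert a.deep_equals(b)  # structural equality is what filters it out
```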

8 changes: 7 additions & 1 deletion codeflash/code_utils/code_replacer.py
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)

new_code: str = replace_functions_and_add_imports(
add_global_assignments(code_to_apply, source_code),
# adding the global assignments before replacing the code, not after,
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import;
# that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
function_names,
code_to_apply,
module_abspath,
1 change: 1 addition & 0 deletions codeflash/context/unused_definition_remover.py
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,
should_add_global_assignments=False,  # we revert helper functions after the optimization has been applied, so the file already has the global assignments added; otherwise they would be added twice.
)
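
A rough sketch of the reasoning behind passing False here, using add_global_assignments from code_extractor with hypothetical file contents:

```python
from codeflash.code_utils.code_extractor import add_global_assignments

# Hypothetical contents: the optimized source carries a module-level
# assignment that the original destination file does not have yet.
optimized_src = 'SEED = 42\n\ndef helper():\n    return SEED\n'
original_dst = 'def helper():\n    return SEED\n'

# During optimization the merge runs once and copies SEED = 42 into the file.
merged = add_global_assignments(optimized_src, original_dst)

# The revert path rewrites helper definitions inside that already-merged
# file, so it sets should_add_global_assignments=False instead of merging
# the same globals a second time.
```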

if reverted_code:
5 changes: 4 additions & 1 deletion codeflash/optimization/function_optimizer.py
@@ -820,7 +820,10 @@ def reformat_code_and_helpers(
return new_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
self,
code_context: CodeOptimizationContext,
optimized_code: CodeStringsMarkdown,
original_helper_code: dict[Path, str],
) -> bool:
did_update = False
read_writable_functions_by_file_path = defaultdict(set)
160 changes: 154 additions & 6 deletions tests/test_code_replacement.py
@@ -1707,7 +1707,6 @@ def new_function2(value):
"""
expected_code = """import numpy as np

print("Hello world")
a=2
print("Hello world")
def some_fn():
@@ -1783,7 +1782,6 @@ def new_function2(value):
"""
expected_code = """import numpy as np

print("Hello world")
[Contributor Author] we are only adding the unique global statements; since the exact print statement was in both the original and optimized code, we should get only one statement in the final code, not two.

print("Hello world")
def some_fn():
a=np.zeros(10)
@@ -1862,7 +1860,6 @@ def new_function2(value):
"""
expected_code = """import numpy as np

print("Hello world")
a=3
print("Hello world")
def some_fn():
@@ -1940,7 +1937,6 @@ def new_function2(value):
"""
expected_code = """import numpy as np

print("Hello world")
a=2
print("Hello world")
def some_fn():
@@ -2019,7 +2015,6 @@ def new_function2(value):
"""
expected_code = """import numpy as np

print("Hello world")
a=3
print("Hello world")
def some_fn():
@@ -2106,7 +2101,6 @@ def new_function2(value):

a = 6

print("Hello world")
if 2<3:
a=4
else:
@@ -3453,3 +3447,157 @@ def hydrate_input_text_actions_with_field_names(
main_file.unlink(missing_ok=True)

assert new_code == expected

def test_duplicate_global_assignments_when_reverting_helpers():
[Contributor Author] this reproduces the unstructured bug: mohammedahmed18/unstructured#1.

root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# -- all detectors need to be called to update state and avoid double counting
# -- boundaries that happen to coincide, like Table and new section on same element.
# -- Using `any()` would short-circuit on first True.
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
return any(semantic_boundaries)
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Short-circuit on the first matching predicate without building an intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
```
'''

func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()

original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code

func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
)


new_code = main_file.read_text(encoding="utf-8")
main_file.unlink(missing_ok=True)

expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Short-circuit on the first matching predicate without building an intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
'''
assert new_code == expected