Skip to content

Commit 9c8256a

Browse files
prevent duplicates for new global statements
1 parent 117e682 commit 9c8256a

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -338,20 +338,28 @@ def delete___future___aliased_imports(module_code: str) -> str:
338338

339339

340340
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
341-
non_assignment_global_statements = extract_global_statements(src_module_code)
342-
343-
# Find the last import line in target
344-
last_import_line = find_last_import_line(dst_module_code)
345-
346-
# Parse the target code
347-
target_module = cst.parse_module(dst_module_code)
348-
349-
# Create transformer to insert non_assignment_global_statements
350-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
351-
#
352-
# # Apply transformation
353-
modified_module = target_module.visit(transformer)
354-
dst_module_code = modified_module.code
341+
new_added_global_statements = extract_global_statements(src_module_code)
342+
existing_global_statements = extract_global_statements(dst_module_code)
343+
344+
unique_global_statements = [
345+
stmt
346+
for stmt in new_added_global_statements
347+
if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements)
348+
]
349+
350+
if unique_global_statements:
351+
# Find the last import line in target
352+
last_import_line = find_last_import_line(dst_module_code)
353+
354+
# Parse the target code
355+
target_module = cst.parse_module(dst_module_code)
356+
357+
# Create transformer to insert new statements
358+
transformer = ImportInserter(unique_global_statements, last_import_line)
359+
#
360+
# # Apply transformation
361+
modified_module = target_module.visit(transformer)
362+
dst_module_code = modified_module.code
355363

356364
# Parse the code
357365
original_module = cst.parse_module(dst_module_code)

codeflash/code_utils/code_replacer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415-
global_assignments_added_before: bool = False, # noqa: FBT001, FBT002
416415
) -> bool:
417416
source_code: str = module_abspath.read_text(encoding="utf8")
418417
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@@ -422,7 +421,7 @@ def replace_function_definitions_in_module(
422421
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
423422
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424423
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425-
add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code,
424+
add_global_assignments(code_to_apply, source_code),
426425
function_names,
427426
code_to_apply,
428427
module_abspath,

codeflash/context/unused_definition_remover.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,6 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540-
global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
541540
)
542541

543542
if reverted_code:

tests/test_code_replacement.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,7 +1707,6 @@ def new_function2(value):
17071707
"""
17081708
expected_code = """import numpy as np
17091709
1710-
print("Hello world")
17111710
a=2
17121711
print("Hello world")
17131712
def some_fn():
@@ -1783,7 +1782,6 @@ def new_function2(value):
17831782
"""
17841783
expected_code = """import numpy as np
17851784
1786-
print("Hello world")
17871785
print("Hello world")
17881786
def some_fn():
17891787
a=np.zeros(10)
@@ -1862,7 +1860,6 @@ def new_function2(value):
18621860
"""
18631861
expected_code = """import numpy as np
18641862
1865-
print("Hello world")
18661863
a=3
18671864
print("Hello world")
18681865
def some_fn():
@@ -1940,7 +1937,6 @@ def new_function2(value):
19401937
"""
19411938
expected_code = """import numpy as np
19421939
1943-
print("Hello world")
19441940
a=2
19451941
print("Hello world")
19461942
def some_fn():
@@ -2019,7 +2015,6 @@ def new_function2(value):
20192015
"""
20202016
expected_code = """import numpy as np
20212017
2022-
print("Hello world")
20232018
a=3
20242019
print("Hello world")
20252020
def some_fn():
@@ -2104,7 +2099,6 @@ def new_function2(value):
21042099
"""
21052100
expected_code = """import numpy as np
21062101
2107-
print("Hello world")
21082102
if 2<3:
21092103
a=4
21102104
else:

0 commit comments

Comments
 (0)