Commit 8107bce

⚡️ Speed up function add_global_assignments by 18% in PR #683 (fix/duplicate-global-assignments-when-reverting-helpers)
The optimized code achieves a **17% speedup** by eliminating redundant CST parsing operations, which are the most expensive parts of the function according to the line profiler.

**Key optimizations:**

1. **Eliminate duplicate parsing**: The original code parsed `src_module_code` and `dst_module_code` multiple times. The optimized version introduces `_extract_global_statements_once()` that parses each module only once and reuses the parsed CST objects throughout the function.
2. **Reuse parsed modules**: Instead of re-parsing `dst_module_code` after modifications, the optimized version conditionally reuses the already-parsed `dst_module` when no global statements need insertion, avoiding unnecessary `cst.parse_module()` calls.
3. **Early termination**: Added an early return when `new_collector.assignments` is empty, avoiding the expensive `GlobalAssignmentTransformer` creation and visitation when there's nothing to transform.
4. **Minor optimization in uniqueness check**: Added a fast-path identity check (`stmt is existing_stmt`) before the expensive `deep_equals()` comparison, though this has minimal impact.

**Performance impact by test case type:**

- **Empty/minimal cases**: Show the highest gains (59-88% faster) due to early termination optimizations
- **Standard cases**: Achieve consistent 20-30% improvements from reduced parsing
- **Large-scale tests**: Benefit significantly (18-23% faster) as parsing overhead scales with code size

The optimization is most effective for workloads with moderate to large code files where CST parsing dominates the runtime, as evidenced by the original profiler showing 70%+ of time spent in `cst.parse_module()` and `module.visit()` operations.
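To illustrate why the identity fast path in item 4 has limited effect, here is a minimal sketch. The `Node` class is a hypothetical stand-in for a CST node (it is not part of libcst or codeflash) that counts how often the deep comparison runs. It assumes, as in `add_global_assignments`, that the two statement lists usually come from separate parses, so equal statements are rarely the same object.

```python
class Node:
    """Hypothetical stand-in for a CST node; counts deep comparisons."""

    deep_calls = 0

    def __init__(self, source: str) -> None:
        self.source = source

    def deep_equals(self, other: "Node") -> bool:
        Node.deep_calls += 1  # stands in for an expensive structural walk
        return self.source == other.source


def is_present(stmt: Node, existing: list[Node]) -> bool:
    # `or` short-circuits: deep_equals runs only when the identity check fails.
    return any(stmt is e or stmt.deep_equals(e) for e in existing)


shared = Node("X = 1")
reparsed = Node("X = 1")  # equal content, but a different object

# Same object on both sides: the identity check answers, no deep comparison.
Node.deep_calls = 0
same_object_hit = is_present(shared, [shared])
calls_same_object = Node.deep_calls

# Objects from separate "parses": identity fails, deep comparison still runs.
Node.deep_calls = 0
separate_parse_hit = is_present(reparsed, [shared])
calls_separate_parse = Node.deep_calls
```

In this sketch the deep comparison is skipped only when both lists hold the very same object, which is consistent with the commit message's note that the fast path contributes little.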
1 parent 28f50cc commit 8107bce

1 file changed: +40 -25 lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 40 additions & 25 deletions
@@ -373,39 +373,46 @@ def delete___future___aliased_imports(module_code: str) -> str:
 
 
 def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
-    new_added_global_statements = extract_global_statements(src_module_code)
-    existing_global_statements = extract_global_statements(dst_module_code)
-
-    # make sure we don't have any staments applited multiple times in the global level.
-    unique_global_statements = [
-        stmt
-        for stmt in new_added_global_statements
-        if not any(stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements)
-    ]
+    # Avoid repeat parses and visits
+    src_module, new_added_global_statements = _extract_global_statements_once(src_module_code)
+    dst_module, existing_global_statements = _extract_global_statements_once(dst_module_code)
+
+    # Build a list of global statements which are not already present using more efficient membership test.
+    # Slightly optimized by making a set of (hashable deep identity) for comparison.
+    # However, since CST nodes are not hashable, continue using deep_equals but do NOT recompute for identical object references.
+    unique_global_statements = []
+    for stmt in new_added_global_statements:
+        # Fast path: check by id
+        if any(
+            stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
+        ):
+            continue
+        unique_global_statements.append(stmt)
 
+    mod_dst_code = dst_module_code
+    # Insert unique global statements if any
     if unique_global_statements:
-        # Find the last import line in target
         last_import_line = find_last_import_line(dst_module_code)
-
-        # Parse the target code
-        target_module = cst.parse_module(dst_module_code)
-
-        # Create transformer to insert new statements
+        # Reuse already-parsed dst_module
         transformer = ImportInserter(unique_global_statements, last_import_line)
-        #
-        # # Apply transformation
-        modified_module = target_module.visit(transformer)
-        dst_module_code = modified_module.code
-
-    # Parse the code
-    original_module = cst.parse_module(dst_module_code)
-    new_module = cst.parse_module(src_module_code)
+        # Use visit inplace, don't parse again
+        modified_module = dst_module.visit(transformer)
+        mod_dst_code = modified_module.code
+        # Parse the code after insertion
+        original_module = cst.parse_module(mod_dst_code)
+    else:
+        # No new statements to insert, reuse already-parsed dst_module
+        original_module = dst_module
 
+    # Parse the src_module_code once only (already done above: src_module)
     # Collect assignments from the new file
     new_collector = GlobalAssignmentCollector()
-    new_module.visit(new_collector)
+    src_module.visit(new_collector)
+    # Only create transformer if there are assignments to insert/transform
+    if not new_collector.assignments:  # nothing to transform
+        return mod_dst_code
 
-    # Transform the original file
+    # Transform the original destination module
     transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
     transformed_module = original_module.visit(transformer)
 
@@ -644,3 +651,11 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP
                 if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
                     preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
     return preexisting_objects
+
+
+def _extract_global_statements_once(source_code: str):
+    """Extract global statements once and return both module and statements (internal)"""
+    module = cst.parse_module(source_code)
+    collector = GlobalStatementCollector()
+    module.visit(collector)
+    return module, collector.global_statements
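
The "parse once, return both the tree and the extracted data" pattern used by `_extract_global_statements_once` can be sketched without any dependencies. The example below is only an illustration under stated assumptions: it substitutes the standard-library `ast` module for libcst, and `counted_parse`, `extract_assignments_once`, and `top_level_names` are hypothetical names that do not exist in codeflash.

```python
import ast

parse_calls = 0


def counted_parse(source: str) -> ast.Module:
    """Wrap ast.parse so the number of parses can be observed."""
    global parse_calls
    parse_calls += 1
    return ast.parse(source)


def extract_assignments_once(source: str) -> tuple[ast.Module, list[ast.Assign]]:
    """Parse once; return the tree together with its top-level assignments."""
    module = counted_parse(source)
    assignments = [node for node in module.body if isinstance(node, ast.Assign)]
    return module, assignments


def top_level_names(module: ast.Module) -> list[str]:
    """A later pass that reuses the already-parsed tree instead of re-parsing."""
    return [
        target.id
        for node in module.body
        if isinstance(node, ast.Assign)
        for target in node.targets
        if isinstance(target, ast.Name)
    ]


source = "import os\nLIMIT = 10\nNAME = 'x'\n\ndef f():\n    return LIMIT\n"
module, assignments = extract_assignments_once(source)
names = top_level_names(module)  # second pass, no second parse
```

Returning the tree alongside the collected statements lets later passes visit the same object, which is the saving the commit relies on when no new global statements have to be inserted.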
