Skip to content

Commit 6446662

Browse files
cleanup
1 parent 8107bce commit 6446662

File tree

1 file changed

+4
-17
lines changed

1 file changed

+4
-17
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,16 +373,11 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
# Avoid repeat parses and visits
377-
src_module, new_added_global_statements = _extract_global_statements_once(src_module_code)
378-
dst_module, existing_global_statements = _extract_global_statements_once(dst_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
379378

380-
# Build a list of global statements which are not already present using more efficient membership test.
381-
# Slightly optimized by making a set of (hashable deep identity) for comparison.
382-
# However, since CST nodes are not hashable, continue using deep_equals but do NOT recompute for identical object references.
383379
unique_global_statements = []
384380
for stmt in new_added_global_statements:
385-
# Fast path: check by id
386381
if any(
387382
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
388383
):
@@ -651,11 +646,3 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP
651646
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
652647
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
653648
return preexisting_objects
654-
655-
656-
def _extract_global_statements_once(source_code: str):
657-
"""Extract global statements once and return both module and statements (internal)"""
658-
module = cst.parse_module(source_code)
659-
collector = GlobalStatementCollector()
660-
module.visit(collector)
661-
return module, collector.global_statements

0 commit comments

Comments
 (0)