Skip to content

Commit 6641f3d

Browse files
committed
Update code_extractor.py
1 parent 9f0f98e commit 6641f3d

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4-
from pathlib import Path
5-
from typing import TYPE_CHECKING, Dict, Optional, Set
4+
from typing import TYPE_CHECKING, Optional
65

76
import libcst as cst
87
import libcst.matchers as m
@@ -11,23 +10,25 @@
1110
from libcst.helpers import calculate_module_and_package
1211

1312
from codeflash.cli_cmds.console import logger
14-
from codeflash.models.models import FunctionParent, FunctionSource
13+
from codeflash.models.models import FunctionParent
1514

1615
if TYPE_CHECKING:
16+
from pathlib import Path
17+
1718
from libcst.helpers import ModuleNameAndPackage
1819

1920
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
21+
from codeflash.models.models import FunctionSource
2022

21-
from typing import List
2223

2324

2425
class GlobalAssignmentCollector(cst.CSTVisitor):
2526
"""Collects all global assignment statements."""
2627

27-
def __init__(self):
28+
def __init__(self) -> None:
2829
super().__init__()
29-
self.assignments: Dict[str, cst.Assign] = {}
30-
self.assignment_order: List[str] = []
30+
self.assignments: dict[str, cst.Assign] = {}
31+
self.assignment_order: list[str] = []
3132
# Track scope depth to identify global assignments
3233
self.scope_depth = 0
3334
self.if_else_depth = 0
@@ -72,11 +73,11 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7273
class GlobalAssignmentTransformer(cst.CSTTransformer):
7374
"""Transforms global assignments in the original file with those from the new file."""
7475

75-
def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
76+
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
7677
super().__init__()
7778
self.new_assignments = new_assignments
7879
self.new_assignment_order = new_assignment_order
79-
self.processed_assignments: Set[str] = set()
80+
self.processed_assignments: set[str] = set()
8081
self.scope_depth = 0
8182
self.if_else_depth = 0
8283

@@ -145,7 +146,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
145146
class GlobalStatementCollector(cst.CSTVisitor):
146147
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
147148

148-
def __init__(self):
149+
def __init__(self) -> None:
149150
super().__init__()
150151
self.global_statements = []
151152
self.in_function_or_class = False
@@ -178,7 +179,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
178179
class LastImportFinder(cst.CSTVisitor):
179180
"""Finds the position of the last import statement in the module."""
180181

181-
def __init__(self):
182+
def __init__(self) -> None:
182183
super().__init__()
183184
self.last_import_line = 0
184185
self.current_line = 0
@@ -193,7 +194,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
193194
class ImportInserter(cst.CSTTransformer):
194195
"""Transformer that inserts global statements after the last import."""
195196

196-
def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
197+
def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
197198
super().__init__()
198199
self.global_statements = global_statements
199200
self.last_import_line = last_import_line
@@ -208,7 +209,7 @@ def leave_SimpleStatementLine(
208209
# If we're right after the last import and haven't inserted yet
209210
if self.current_line == self.last_import_line and not self.inserted:
210211
self.inserted = True
211-
return cst.Module(body=[updated_node] + self.global_statements)
212+
return cst.Module(body=[updated_node, *self.global_statements])
212213

213214
return cst.Module(body=[updated_node])
214215

@@ -222,7 +223,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
222223
return updated_node
223224

224225

225-
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
226+
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
226227
"""Extract global statements from source code."""
227228
module = cst.parse_module(source_code)
228229
collector = GlobalStatementCollector()
@@ -285,8 +286,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
285286
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
286287
transformed_module = original_module.visit(transformer)
287288

288-
dst_module_code = transformed_module.code
289-
return dst_module_code
289+
return transformed_module.code
290290

291291

292292
def add_needed_imports_from_module(

0 commit comments

Comments
 (0)