Skip to content

Commit b349a21

Browse files
Merge pull request #665 from codeflash-ai/fix/duplicate-imports
[FIX] Prevent cst duplicate imports
2 parents a2825f9 + 00924b6 commit b349a21

File tree

2 files changed

+260
-18
lines changed

2 files changed

+260
-18
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,19 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198-
class ConditionalImportCollector(cst.CSTVisitor):
199-
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
198+
class DottedImportCollector(cst.CSTVisitor):
199+
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
200+
201+
Examples
202+
--------
203+
import os ==> "os"
204+
import dbt.adapters.factory ==> "dbt.adapters.factory"
205+
from pathlib import Path ==> "pathlib.Path"
206+
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
207+
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
208+
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
209+
210+
"""
200211

201212
def __init__(self) -> None:
202213
self.imports: set[str] = set()
@@ -217,7 +228,10 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
217228
for alias in child.names:
218229
module = self.get_full_dotted_name(alias.name)
219230
asname = alias.asname.name.value if alias.asname else alias.name.value
220-
self.imports.add(module if module == asname else f"{module}.{asname}")
231+
if isinstance(asname, cst.Attribute):
232+
self.imports.add(module)
233+
else:
234+
self.imports.add(module if module == asname else f"{module}.{asname}")
221235

222236
elif isinstance(child, cst.ImportFrom):
223237
if child.module is None:
@@ -231,6 +245,7 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
231245

232246
def visit_Module(self, node: cst.Module) -> None:
233247
self.depth = 0
248+
self._collect_imports_from_block(node)
234249

235250
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
236251
self.depth += 1
@@ -388,45 +403,44 @@ def add_needed_imports_from_module(
388403
logger.error(f"Error parsing source module code: {e}")
389404
return dst_module_code
390405

391-
cond_import_collector = ConditionalImportCollector()
406+
dotted_import_collector = DottedImportCollector()
392407
try:
393408
parsed_dst_module = cst.parse_module(dst_module_code)
394-
parsed_dst_module.visit(cond_import_collector)
409+
parsed_dst_module.visit(dotted_import_collector)
395410
except cst.ParserSyntaxError as e:
396411
logger.exception(f"Syntax error in destination module code: {e}")
397412
return dst_module_code # Return the original code if there's a syntax error
398413

399414
try:
400415
for mod in gatherer.module_imports:
401-
if mod in cond_import_collector.imports:
402-
continue
403-
AddImportsVisitor.add_needed_import(dst_context, mod)
416+
if mod not in dotted_import_collector.imports:
417+
AddImportsVisitor.add_needed_import(dst_context, mod)
404418
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
405419
for mod, obj_seq in gatherer.object_mapping.items():
406420
for obj in obj_seq:
407421
if (
408422
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
409423
):
410424
continue # Skip adding imports for helper functions already in the context
411-
if f"{mod}.{obj}" in cond_import_collector.imports:
412-
continue
413-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
425+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
426+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
414427
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
415428
except Exception as e:
416429
logger.exception(f"Error adding imports to destination module code: {e}")
417430
return dst_module_code
431+
418432
for mod, asname in gatherer.module_aliases.items():
419-
if f"{mod}.{asname}" in cond_import_collector.imports:
420-
continue
421-
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
433+
if f"{mod}.{asname}" not in dotted_import_collector.imports:
434+
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
422435
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
436+
423437
for mod, alias_pairs in gatherer.alias_mapping.items():
424438
for alias_pair in alias_pairs:
425439
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
426440
continue
427-
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428-
continue
429-
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
441+
442+
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
443+
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
430444
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
431445

432446
try:

tests/test_add_needed_imports_from_module.py

Lines changed: 229 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

3-
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
3+
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
4+
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
45

56

67
def test_add_needed_imports_from_module0() -> None:
@@ -121,3 +122,230 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
121122
project_root = Path("/home/roger/repos/codeflash")
122123
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
123124
assert new_module == expected
125+
126+
def test_duplicated_imports() -> None:
127+
optim_code = '''from dataclasses import dataclass
128+
from recce.adapter.base import BaseAdapter
129+
from typing import Dict, List, Optional
130+
131+
@dataclass
132+
class DbtAdapter(BaseAdapter):
133+
134+
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
135+
manifest = self.curr_manifest if base is False else self.base_manifest
136+
137+
try:
138+
parent_map_source = manifest.parent_map
139+
except AttributeError:
140+
parent_map_source = manifest.to_dict()["parent_map"]
141+
142+
node_ids = set(nodes)
143+
parent_map = {}
144+
for k, parents in parent_map_source.items():
145+
if k not in node_ids:
146+
continue
147+
parent_map[k] = [parent for parent in parents if parent in node_ids]
148+
149+
return parent_map
150+
'''
151+
152+
original_code = '''import json
153+
import logging
154+
import os
155+
import uuid
156+
from contextlib import contextmanager
157+
from copy import deepcopy
158+
from dataclasses import dataclass, fields
159+
from errno import ENOENT
160+
from functools import lru_cache
161+
from pathlib import Path
162+
from typing import (
163+
Any,
164+
Callable,
165+
Dict,
166+
Iterator,
167+
List,
168+
Literal,
169+
Optional,
170+
Set,
171+
Tuple,
172+
Type,
173+
Union,
174+
)
175+
176+
from recce.event import log_performance
177+
from recce.exceptions import RecceException
178+
from recce.util.cll import CLLPerformanceTracking, cll
179+
from recce.util.lineage import (
180+
build_column_key,
181+
filter_dependency_maps,
182+
find_downstream,
183+
find_upstream,
184+
)
185+
from recce.util.perf_tracking import LineagePerfTracker
186+
187+
from ...tasks.profile import ProfileTask
188+
from ...util.breaking import BreakingPerformanceTracking, parse_change_category
189+
190+
try:
191+
import agate
192+
import dbt.adapters.factory
193+
from dbt.contracts.state import PreviousState
194+
except ImportError as e:
195+
print("Error: dbt module not found. Please install it by running:")
196+
print("pip install dbt-core dbt-<adapter>")
197+
raise e
198+
from watchdog.events import FileSystemEventHandler
199+
from watchdog.observers import Observer
200+
201+
from recce.adapter.base import BaseAdapter
202+
from recce.state import ArtifactsRoot
203+
204+
from ...models import RunType
205+
from ...models.types import (
206+
CllColumn,
207+
CllData,
208+
CllNode,
209+
LineageDiff,
210+
NodeChange,
211+
NodeDiff,
212+
)
213+
from ...tasks import (
214+
HistogramDiffTask,
215+
ProfileDiffTask,
216+
QueryBaseTask,
217+
QueryDiffTask,
218+
QueryTask,
219+
RowCountDiffTask,
220+
RowCountTask,
221+
Task,
222+
TopKDiffTask,
223+
ValueDiffDetailTask,
224+
ValueDiffTask,
225+
)
226+
from .dbt_version import DbtVersion
227+
228+
@dataclass
229+
class DbtAdapter(BaseAdapter):
230+
231+
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
232+
manifest = self.curr_manifest if base is False else self.base_manifest
233+
manifest_dict = manifest.to_dict()
234+
235+
node_ids = nodes.keys()
236+
parent_map = {}
237+
for k, parents in manifest_dict["parent_map"].items():
238+
if k not in node_ids:
239+
continue
240+
parent_map[k] = [parent for parent in parents if parent in node_ids]
241+
242+
return parent_map
243+
'''
244+
expected = '''import json
245+
import logging
246+
import os
247+
import uuid
248+
from contextlib import contextmanager
249+
from copy import deepcopy
250+
from dataclasses import dataclass, fields
251+
from errno import ENOENT
252+
from functools import lru_cache
253+
from pathlib import Path
254+
from typing import (
255+
Any,
256+
Callable,
257+
Dict,
258+
Iterator,
259+
List,
260+
Literal,
261+
Optional,
262+
Set,
263+
Tuple,
264+
Type,
265+
Union,
266+
)
267+
268+
from recce.event import log_performance
269+
from recce.exceptions import RecceException
270+
from recce.util.cll import CLLPerformanceTracking, cll
271+
from recce.util.lineage import (
272+
build_column_key,
273+
filter_dependency_maps,
274+
find_downstream,
275+
find_upstream,
276+
)
277+
from recce.util.perf_tracking import LineagePerfTracker
278+
279+
from ...tasks.profile import ProfileTask
280+
from ...util.breaking import BreakingPerformanceTracking, parse_change_category
281+
282+
try:
283+
import agate
284+
import dbt.adapters.factory
285+
from dbt.contracts.state import PreviousState
286+
except ImportError as e:
287+
print("Error: dbt module not found. Please install it by running:")
288+
print("pip install dbt-core dbt-<adapter>")
289+
raise e
290+
from watchdog.events import FileSystemEventHandler
291+
from watchdog.observers import Observer
292+
293+
from recce.adapter.base import BaseAdapter
294+
from recce.state import ArtifactsRoot
295+
296+
from ...models import RunType
297+
from ...models.types import (
298+
CllColumn,
299+
CllData,
300+
CllNode,
301+
LineageDiff,
302+
NodeChange,
303+
NodeDiff,
304+
)
305+
from ...tasks import (
306+
HistogramDiffTask,
307+
ProfileDiffTask,
308+
QueryBaseTask,
309+
QueryDiffTask,
310+
QueryTask,
311+
RowCountDiffTask,
312+
RowCountTask,
313+
Task,
314+
TopKDiffTask,
315+
ValueDiffDetailTask,
316+
ValueDiffTask,
317+
)
318+
from .dbt_version import DbtVersion
319+
320+
@dataclass
321+
class DbtAdapter(BaseAdapter):
322+
323+
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
324+
manifest = self.curr_manifest if base is False else self.base_manifest
325+
326+
try:
327+
parent_map_source = manifest.parent_map
328+
except AttributeError:
329+
parent_map_source = manifest.to_dict()["parent_map"]
330+
331+
node_ids = set(nodes)
332+
parent_map = {}
333+
for k, parents in parent_map_source.items():
334+
if k not in node_ids:
335+
continue
336+
parent_map[k] = [parent for parent in parents if parent in node_ids]
337+
338+
return parent_map
339+
'''
340+
341+
function_name: str = "DbtAdapter.build_parent_map"
342+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
343+
new_code: str = replace_functions_and_add_imports(
344+
source_code=original_code,
345+
function_names=[function_name],
346+
optimized_code=optim_code,
347+
module_abspath=Path(__file__).resolve(),
348+
preexisting_objects=preexisting_objects,
349+
project_root_path=Path(__file__).resolve().parent.resolve(),
350+
)
351+
assert new_code == expected

0 commit comments

Comments
 (0)