@@ -195,6 +195,79 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195 self .last_import_line = self .current_line
196196
197197
198+ class DottedImportCollector (cst .CSTVisitor ):
199+ """Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
200+
201+ Examples
202+ --------
203+ import os ==> "os"
204+ import dbt.adapters.factory ==> "dbt.adapters.factory"
205+ from pathlib import Path ==> "pathlib.Path"
206+ from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
207+ from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
208+ from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
209+
210+ """
211+
212+ def __init__ (self ) -> None :
213+ self .imports : set [str ] = set ()
214+ self .depth = 0 # top-level
215+
216+ def get_full_dotted_name (self , expr : cst .BaseExpression ) -> str :
217+ if isinstance (expr , cst .Name ):
218+ return expr .value
219+ if isinstance (expr , cst .Attribute ):
220+ return f"{ self .get_full_dotted_name (expr .value )} .{ expr .attr .value } "
221+ return ""
222+
223+ def _collect_imports_from_block (self , block : cst .IndentedBlock ) -> None :
224+ for statement in block .body :
225+ if isinstance (statement , cst .SimpleStatementLine ):
226+ for child in statement .body :
227+ if isinstance (child , cst .Import ):
228+ for alias in child .names :
229+ module = self .get_full_dotted_name (alias .name )
230+ asname = alias .asname .name .value if alias .asname else alias .name .value
231+ if isinstance (asname , cst .Attribute ):
232+ self .imports .add (module )
233+ else :
234+ self .imports .add (module if module == asname else f"{ module } .{ asname } " )
235+
236+ elif isinstance (child , cst .ImportFrom ):
237+ if child .module is None :
238+ continue
239+ module = self .get_full_dotted_name (child .module )
240+ for alias in child .names :
241+ if isinstance (alias , cst .ImportAlias ):
242+ name = alias .name .value
243+ asname = alias .asname .name .value if alias .asname else name
244+ self .imports .add (f"{ module } .{ asname } " )
245+
246+ def visit_Module (self , node : cst .Module ) -> None :
247+ self .depth = 0
248+ self ._collect_imports_from_block (node )
249+
250+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
251+ self .depth += 1
252+
253+ def leave_FunctionDef (self , node : cst .FunctionDef ) -> None :
254+ self .depth -= 1
255+
256+ def visit_ClassDef (self , node : cst .ClassDef ) -> None :
257+ self .depth += 1
258+
259+ def leave_ClassDef (self , node : cst .ClassDef ) -> None :
260+ self .depth -= 1
261+
262+ def visit_If (self , node : cst .If ) -> None :
263+ if self .depth == 0 :
264+ self ._collect_imports_from_block (node .body )
265+
266+ def visit_Try (self , node : cst .Try ) -> None :
267+ if self .depth == 0 :
268+ self ._collect_imports_from_block (node .body )
269+
270+
198271class ImportInserter (cst .CSTTransformer ):
199272 """Transformer that inserts global statements after the last import."""
200273
@@ -329,38 +402,49 @@ def add_needed_imports_from_module(
329402 except Exception as e :
330403 logger .error (f"Error parsing source module code: { e } " )
331404 return dst_module_code
405+
406+ dotted_import_collector = DottedImportCollector ()
407+ try :
408+ parsed_dst_module = cst .parse_module (dst_module_code )
409+ parsed_dst_module .visit (dotted_import_collector )
410+ except cst .ParserSyntaxError as e :
411+ logger .exception (f"Syntax error in destination module code: { e } " )
412+ return dst_module_code # Return the original code if there's a syntax error
413+
332414 try :
333415 for mod in gatherer .module_imports :
334- AddImportsVisitor .add_needed_import (dst_context , mod )
416+ if mod not in dotted_import_collector .imports :
417+ AddImportsVisitor .add_needed_import (dst_context , mod )
335418 RemoveImportsVisitor .remove_unused_import (dst_context , mod )
336419 for mod , obj_seq in gatherer .object_mapping .items ():
337420 for obj in obj_seq :
338421 if (
339422 f"{ mod } .{ obj } " in helper_functions_fqn or dst_context .full_module_name == mod # avoid circular deps
340423 ):
341424 continue # Skip adding imports for helper functions already in the context
342- AddImportsVisitor .add_needed_import (dst_context , mod , obj )
425+ if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
426+ AddImportsVisitor .add_needed_import (dst_context , mod , obj )
343427 RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
344428 except Exception as e :
345429 logger .exception (f"Error adding imports to destination module code: { e } " )
346430 return dst_module_code
431+
347432 for mod , asname in gatherer .module_aliases .items ():
348- AddImportsVisitor .add_needed_import (dst_context , mod , asname = asname )
433+ if f"{ mod } .{ asname } " not in dotted_import_collector .imports :
434+ AddImportsVisitor .add_needed_import (dst_context , mod , asname = asname )
349435 RemoveImportsVisitor .remove_unused_import (dst_context , mod , asname = asname )
436+
350437 for mod , alias_pairs in gatherer .alias_mapping .items ():
351438 for alias_pair in alias_pairs :
352439 if f"{ mod } .{ alias_pair [0 ]} " in helper_functions_fqn :
353440 continue
354- AddImportsVisitor .add_needed_import (dst_context , mod , alias_pair [0 ], asname = alias_pair [1 ])
441+
442+ if f"{ mod } .{ alias_pair [1 ]} " not in dotted_import_collector .imports :
443+ AddImportsVisitor .add_needed_import (dst_context , mod , alias_pair [0 ], asname = alias_pair [1 ])
355444 RemoveImportsVisitor .remove_unused_import (dst_context , mod , alias_pair [0 ], asname = alias_pair [1 ])
356445
357446 try :
358- parsed_module = cst .parse_module (dst_module_code )
359- except cst .ParserSyntaxError as e :
360- logger .exception (f"Syntax error in destination module code: { e } " )
361- return dst_module_code # Return the original code if there's a syntax error
362- try :
363- transformed_module = AddImportsVisitor (dst_context ).transform_module (parsed_module )
447+ transformed_module = AddImportsVisitor (dst_context ).transform_module (parsed_dst_module )
364448 transformed_module = RemoveImportsVisitor (dst_context ).transform_module (transformed_module )
365449 return transformed_module .code .lstrip ("\n " )
366450 except Exception as e :
0 commit comments