@@ -195,6 +195,64 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195 self .last_import_line = self .current_line
196196
197197
198+ class ConditionalImportCollector (cst .CSTVisitor ):
199+ """Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
200+
201+ def __init__ (self ) -> None :
202+ self .imports : set [str ] = set ()
203+ self .depth = 0 # top-level
204+
205+ def get_full_dotted_name (self , expr : cst .BaseExpression ) -> str :
206+ if isinstance (expr , cst .Name ):
207+ return expr .value
208+ if isinstance (expr , cst .Attribute ):
209+ return f"{ self .get_full_dotted_name (expr .value )} .{ expr .attr .value } "
210+ return ""
211+
212+ def _collect_imports_from_block (self , block : cst .IndentedBlock ) -> None :
213+ for statement in block .body :
214+ if isinstance (statement , cst .SimpleStatementLine ):
215+ for child in statement .body :
216+ if isinstance (child , cst .Import ):
217+ for alias in child .names :
218+ module = self .get_full_dotted_name (alias .name )
219+ asname = alias .asname .name .value if alias .asname else alias .name .value
220+ self .imports .add (module if module == asname else f"{ module } .{ asname } " )
221+
222+ elif isinstance (child , cst .ImportFrom ):
223+ if child .module is None :
224+ continue
225+ module = self .get_full_dotted_name (child .module )
226+ for alias in child .names :
227+ if isinstance (alias , cst .ImportAlias ):
228+ name = alias .name .value
229+ asname = alias .asname .name .value if alias .asname else name
230+ self .imports .add (f"{ module } .{ asname } " )
231+
232+ def visit_Module (self , node : cst .Module ) -> None :
233+ self .depth = 0
234+
235+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
236+ self .depth += 1
237+
238+ def leave_FunctionDef (self , node : cst .FunctionDef ) -> None :
239+ self .depth -= 1
240+
241+ def visit_ClassDef (self , node : cst .ClassDef ) -> None :
242+ self .depth += 1
243+
244+ def leave_ClassDef (self , node : cst .ClassDef ) -> None :
245+ self .depth -= 1
246+
247+ def visit_If (self , node : cst .If ) -> None :
248+ if self .depth == 0 :
249+ self ._collect_imports_from_block (node .body )
250+
251+ def visit_Try (self , node : cst .Try ) -> None :
252+ if self .depth == 0 :
253+ self ._collect_imports_from_block (node .body )
254+
255+
198256class ImportInserter (cst .CSTTransformer ):
199257 """Transformer that inserts global statements after the last import."""
200258
@@ -329,8 +387,19 @@ def add_needed_imports_from_module(
329387 except Exception as e :
330388 logger .error (f"Error parsing source module code: { e } " )
331389 return dst_module_code
390+
391+ cond_import_collector = ConditionalImportCollector ()
392+ try :
393+ parsed_dst_module = cst .parse_module (dst_module_code )
394+ parsed_dst_module .visit (cond_import_collector )
395+ except cst .ParserSyntaxError as e :
396+ logger .exception (f"Syntax error in destination module code: { e } " )
397+ return dst_module_code # Return the original code if there's a syntax error
398+
332399 try :
333400 for mod in gatherer .module_imports :
401+ if mod in cond_import_collector .imports :
402+ continue
334403 AddImportsVisitor .add_needed_import (dst_context , mod )
335404 RemoveImportsVisitor .remove_unused_import (dst_context , mod )
336405 for mod , obj_seq in gatherer .object_mapping .items ():
@@ -339,28 +408,29 @@ def add_needed_imports_from_module(
339408 f"{ mod } .{ obj } " in helper_functions_fqn or dst_context .full_module_name == mod # avoid circular deps
340409 ):
341410 continue # Skip adding imports for helper functions already in the context
411+ if f"{ mod } .{ obj } " in cond_import_collector .imports :
412+ continue
342413 AddImportsVisitor .add_needed_import (dst_context , mod , obj )
343414 RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
344415 except Exception as e :
345416 logger .exception (f"Error adding imports to destination module code: { e } " )
346417 return dst_module_code
347418 for mod , asname in gatherer .module_aliases .items ():
419+ if f"{ mod } .{ asname } " in cond_import_collector .imports :
420+ continue
348421 AddImportsVisitor .add_needed_import (dst_context , mod , asname = asname )
349422 RemoveImportsVisitor .remove_unused_import (dst_context , mod , asname = asname )
350423 for mod , alias_pairs in gatherer .alias_mapping .items ():
351424 for alias_pair in alias_pairs :
352425 if f"{ mod } .{ alias_pair [0 ]} " in helper_functions_fqn :
353426 continue
427+ if f"{ mod } .{ alias_pair [1 ]} " in cond_import_collector .imports :
428+ continue
354429 AddImportsVisitor .add_needed_import (dst_context , mod , alias_pair [0 ], asname = alias_pair [1 ])
355430 RemoveImportsVisitor .remove_unused_import (dst_context , mod , alias_pair [0 ], asname = alias_pair [1 ])
356431
357432 try :
358- parsed_module = cst .parse_module (dst_module_code )
359- except cst .ParserSyntaxError as e :
360- logger .exception (f"Syntax error in destination module code: { e } " )
361- return dst_module_code # Return the original code if there's a syntax error
362- try :
363- transformed_module = AddImportsVisitor (dst_context ).transform_module (parsed_module )
433+ transformed_module = AddImportsVisitor (dst_context ).transform_module (parsed_dst_module )
364434 transformed_module = RemoveImportsVisitor (dst_context ).transform_module (transformed_module )
365435 return transformed_module .code .lstrip ("\n " )
366436 except Exception as e :
0 commit comments