@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272 if child .module is None :
273273 continue
274274 module = self .get_full_dotted_name (child .module )
275+ if isinstance (child .names , cst .ImportStar ):
276+ continue
275277 for alias in child .names :
276278 if isinstance (alias , cst .ImportAlias ):
277279 name = alias .name .value
@@ -403,6 +405,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
403405 return transformed_module .code
404406
405407
408+ def resolve_star_import (module_name : str , project_root : Path ) -> set [str ]:
409+ try :
410+ module_path = module_name .replace ("." , "/" )
411+ possible_paths = [project_root / f"{ module_path } .py" , project_root / f"{ module_path } /__init__.py" ]
412+
413+ module_file = None
414+ for path in possible_paths :
415+ if path .exists ():
416+ module_file = path
417+ break
418+
419+ if module_file is None :
420+ logger .warning (f"Could not find module file for { module_name } , skipping star import resolution" )
421+ return set ()
422+
423+ with module_file .open (encoding = "utf8" ) as f :
424+ module_code = f .read ()
425+
426+ tree = ast .parse (module_code )
427+
428+ all_names = None
429+ for node in ast .walk (tree ):
430+ if (
431+ isinstance (node , ast .Assign )
432+ and len (node .targets ) == 1
433+ and isinstance (node .targets [0 ], ast .Name )
434+ and node .targets [0 ].id == "__all__"
435+ ):
436+ if isinstance (node .value , (ast .List , ast .Tuple )):
437+ all_names = []
438+ for elt in node .value .elts :
439+ if isinstance (elt , ast .Constant ) and isinstance (elt .value , str ):
440+ all_names .append (elt .value )
441+ elif isinstance (elt , ast .Str ): # Python < 3.8 compatibility
442+ all_names .append (elt .s )
443+ break
444+
445+ if all_names is not None :
446+ return set (all_names )
447+
448+ public_names = set ()
449+ for node in tree .body :
450+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
451+ if not node .name .startswith ("_" ):
452+ public_names .add (node .name )
453+ elif isinstance (node , ast .Assign ):
454+ for target in node .targets :
455+ if isinstance (target , ast .Name ) and not target .id .startswith ("_" ):
456+ public_names .add (target .id )
457+ elif isinstance (node , ast .AnnAssign ):
458+ if isinstance (node .target , ast .Name ) and not node .target .id .startswith ("_" ):
459+ public_names .add (node .target .id )
460+ elif isinstance (node , ast .Import ) or (
461+ isinstance (node , ast .ImportFrom ) and not any (alias .name == "*" for alias in node .names )
462+ ):
463+ for alias in node .names :
464+ name = alias .asname or alias .name
465+ if not name .startswith ("_" ):
466+ public_names .add (name )
467+
468+ return public_names # noqa: TRY300
469+
470+ except Exception as e :
471+ logger .warning (f"Error resolving star import for { module_name } : { e } " )
472+ return set ()
473+
474+
406475def add_needed_imports_from_module (
407476 src_module_code : str ,
408477 dst_module_code : str ,
@@ -457,9 +526,23 @@ def add_needed_imports_from_module(
457526 f"{ mod } .{ obj } " in helper_functions_fqn or dst_context .full_module_name == mod # avoid circular deps
458527 ):
459528 continue # Skip adding imports for helper functions already in the context
460- if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
461- AddImportsVisitor .add_needed_import (dst_context , mod , obj )
462- RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
529+
530+ # Handle star imports by resolving them to actual symbol names
531+ if obj == "*" :
532+ resolved_symbols = resolve_star_import (mod , project_root )
533+ logger .debug (f"Resolved star import from { mod } : { resolved_symbols } " )
534+
535+ for symbol in resolved_symbols :
536+ if (
537+ f"{ mod } .{ symbol } " not in helper_functions_fqn
538+ and f"{ mod } .{ symbol } " not in dotted_import_collector .imports
539+ ):
540+ AddImportsVisitor .add_needed_import (dst_context , mod , symbol )
541+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , symbol )
542+ else :
543+ if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
544+ AddImportsVisitor .add_needed_import (dst_context , mod , obj )
545+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
463546 except Exception as e :
464547 logger .exception (f"Error adding imports to destination module code: { e } " )
465548 return dst_module_code
0 commit comments