@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272
272
if child .module is None :
273
273
continue
274
274
module = self .get_full_dotted_name (child .module )
275
+ if isinstance (child .names , cst .ImportStar ):
276
+ continue
275
277
for alias in child .names :
276
278
if isinstance (alias , cst .ImportAlias ):
277
279
name = alias .name .value
@@ -403,6 +405,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
403
405
return transformed_module .code
404
406
405
407
408
+ def resolve_star_import (module_name : str , project_root : Path ) -> set [str ]:
409
+ try :
410
+ module_path = module_name .replace ("." , "/" )
411
+ possible_paths = [project_root / f"{ module_path } .py" , project_root / f"{ module_path } /__init__.py" ]
412
+
413
+ module_file = None
414
+ for path in possible_paths :
415
+ if path .exists ():
416
+ module_file = path
417
+ break
418
+
419
+ if module_file is None :
420
+ logger .warning (f"Could not find module file for { module_name } , skipping star import resolution" )
421
+ return set ()
422
+
423
+ with module_file .open (encoding = "utf8" ) as f :
424
+ module_code = f .read ()
425
+
426
+ tree = ast .parse (module_code )
427
+
428
+ all_names = None
429
+ for node in ast .walk (tree ):
430
+ if (
431
+ isinstance (node , ast .Assign )
432
+ and len (node .targets ) == 1
433
+ and isinstance (node .targets [0 ], ast .Name )
434
+ and node .targets [0 ].id == "__all__"
435
+ ):
436
+ if isinstance (node .value , (ast .List , ast .Tuple )):
437
+ all_names = []
438
+ for elt in node .value .elts :
439
+ if isinstance (elt , ast .Constant ) and isinstance (elt .value , str ):
440
+ all_names .append (elt .value )
441
+ elif isinstance (elt , ast .Str ): # Python < 3.8 compatibility
442
+ all_names .append (elt .s )
443
+ break
444
+
445
+ if all_names is not None :
446
+ return set (all_names )
447
+ else :
448
+ public_names = set ()
449
+ for node in tree .body :
450
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
451
+ if not node .name .startswith ("_" ):
452
+ public_names .add (node .name )
453
+ elif isinstance (node , ast .Assign ):
454
+ for target in node .targets :
455
+ if isinstance (target , ast .Name ) and not target .id .startswith ("_" ):
456
+ public_names .add (target .id )
457
+ elif isinstance (node , ast .AnnAssign ):
458
+ if isinstance (node .target , ast .Name ) and not node .target .id .startswith ("_" ):
459
+ public_names .add (node .target .id )
460
+ elif isinstance (node , ast .Import ) or (
461
+ isinstance (node , ast .ImportFrom ) and not any (alias .name == "*" for alias in node .names )
462
+ ):
463
+ for alias in node .names :
464
+ name = alias .asname or alias .name
465
+ if not name .startswith ("_" ):
466
+ public_names .add (name )
467
+
468
+ return public_names
469
+
470
+ except Exception as e :
471
+ logger .warning (f"Error resolving star import for { module_name } : { e } " )
472
+ return set ()
473
+
474
+
406
475
def add_needed_imports_from_module (
407
476
src_module_code : str ,
408
477
dst_module_code : str ,
@@ -457,9 +526,23 @@ def add_needed_imports_from_module(
457
526
f"{ mod } .{ obj } " in helper_functions_fqn or dst_context .full_module_name == mod # avoid circular deps
458
527
):
459
528
continue # Skip adding imports for helper functions already in the context
460
- if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
461
- AddImportsVisitor .add_needed_import (dst_context , mod , obj )
462
- RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
529
+
530
+ # Handle star imports by resolving them to actual symbol names
531
+ if obj == "*" :
532
+ resolved_symbols = resolve_star_import (mod , project_root )
533
+ logger .debug (f"Resolved star import from { mod } : { resolved_symbols } " )
534
+
535
+ for symbol in resolved_symbols :
536
+ if (
537
+ f"{ mod } .{ symbol } " not in helper_functions_fqn
538
+ and f"{ mod } .{ symbol } " not in dotted_import_collector .imports
539
+ ):
540
+ AddImportsVisitor .add_needed_import (dst_context , mod , symbol )
541
+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , symbol )
542
+ else :
543
+ if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
544
+ AddImportsVisitor .add_needed_import (dst_context , mod , obj )
545
+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
463
546
except Exception as e :
464
547
logger .exception (f"Error adding imports to destination module code: { e } " )
465
548
return dst_module_code
0 commit comments