@@ -786,25 +786,20 @@ def visit_func_def(self, o: FuncDef) -> None:
786786 elif o .name in KNOWN_MAGIC_METHODS_RETURN_TYPES :
787787 retname = KNOWN_MAGIC_METHODS_RETURN_TYPES [o .name ]
788788 elif has_yield_expression (o ) or has_yield_from_expression (o ):
789- self .add_typing_import ("Generator" )
789+ generator_name = self .add_typing_import ("Generator" )
790790 yield_name = "None"
791791 send_name = "None"
792792 return_name = "None"
793793 if has_yield_from_expression (o ):
794- self .add_typing_import ("Incomplete" )
795- yield_name = send_name = self .typing_name ("Incomplete" )
794+ yield_name = send_name = self .add_typing_import ("Incomplete" )
796795 else :
797796 for expr , in_assignment in all_yield_expressions (o ):
798797 if expr .expr is not None and not self .is_none_expr (expr .expr ):
799- self .add_typing_import ("Incomplete" )
800- yield_name = self .typing_name ("Incomplete" )
798+ yield_name = self .add_typing_import ("Incomplete" )
801799 if in_assignment :
802- self .add_typing_import ("Incomplete" )
803- send_name = self .typing_name ("Incomplete" )
800+ send_name = self .add_typing_import ("Incomplete" )
804801 if has_return_statement (o ):
805- self .add_typing_import ("Incomplete" )
806- return_name = self .typing_name ("Incomplete" )
807- generator_name = self .typing_name ("Generator" )
802+ return_name = self .add_typing_import ("Incomplete" )
808803 retname = f"{ generator_name } [{ yield_name } , { send_name } , { return_name } ]"
809804 elif not has_return_statement (o ) and o .abstract_status == NOT_ABSTRACT :
810805 retname = "None"
@@ -965,21 +960,19 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
965960 nt_fields = self ._get_namedtuple_fields (base )
966961 assert isinstance (base .args [0 ], StrExpr )
967962 typename = base .args [0 ].value
968- if nt_fields is not None :
969- fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
970- namedtuple_name = self .typing_name ("NamedTuple" )
971- base_types .append (f"{ namedtuple_name } ({ typename !r} , [{ fields_str } ])" )
972- self .add_typing_import ("NamedTuple" )
973- else :
963+ if nt_fields is None :
974964 # Invalid namedtuple() call, cannot determine fields
975- base_types .append (self .typing_name ("Incomplete" ))
965+ base_types .append (self .add_typing_import ("Incomplete" ))
966+ continue
967+ fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
968+ namedtuple_name = self .add_typing_import ("NamedTuple" )
969+ base_types .append (f"{ namedtuple_name } ({ typename !r} , [{ fields_str } ])" )
976970 elif self .is_typed_namedtuple (base ):
977971 base_types .append (base .accept (p ))
978972 else :
979973 # At this point, we don't know what the base class is, so we
980974 # just use Incomplete as the base class.
981- base_types .append (self .typing_name ("Incomplete" ))
982- self .add_typing_import ("Incomplete" )
975+ base_types .append (self .add_typing_import ("Incomplete" ))
983976 for name , value in cdef .keywords .items ():
984977 if name == "metaclass" :
985978 continue # handled separately
@@ -1059,9 +1052,9 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None
10591052 field_names .append (field .value )
10601053 else :
10611054 return None # Invalid namedtuple fields type
1062- if field_names :
1063- self . add_typing_import ( "Incomplete" )
1064- incomplete = self .typing_name ("Incomplete" )
1055+ if not field_names :
1056+ return []
1057+ incomplete = self .add_typing_import ("Incomplete" )
10651058 return [(field_name , incomplete ) for field_name in field_names ]
10661059 elif self .is_typed_namedtuple (call ):
10671060 fields_arg = call .args [1 ]
@@ -1092,8 +1085,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
10921085 if fields is None :
10931086 self .annotate_as_incomplete (lvalue )
10941087 return
1095- self .add_typing_import ("NamedTuple" )
1096- bases = self .typing_name ("NamedTuple" )
1088+ bases = self .add_typing_import ("NamedTuple" )
10971089 # TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
10981090 class_def = f"{ self ._indent } class { lvalue .name } ({ bases } ):"
10991091 if len (fields ) == 0 :
@@ -1143,14 +1135,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
11431135 total = arg
11441136 else :
11451137 items .append ((arg_name , arg ))
1146- self .add_typing_import ("TypedDict" )
1138+ bases = self .add_typing_import ("TypedDict" )
11471139 p = AliasPrinter (self )
11481140 if any (not key .isidentifier () or keyword .iskeyword (key ) for key , _ in items ):
11491141 # Keep the call syntax if there are non-identifier or reserved keyword keys.
11501142 self .add (f"{ self ._indent } { lvalue .name } = { rvalue .accept (p )} \n " )
11511143 self ._state = VAR
11521144 else :
1153- bases = self .typing_name ("TypedDict" )
11541145 # TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
11551146 if total is not None :
11561147 bases += f", total={ total .accept (p )} "
@@ -1167,8 +1158,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
11671158 self ._state = CLASS
11681159
11691160 def annotate_as_incomplete (self , lvalue : NameExpr ) -> None :
1170- self .add_typing_import ("Incomplete" )
1171- self .add (f"{ self ._indent } { lvalue .name } : { self .typing_name ('Incomplete' )} \n " )
1161+ self .add (f"{ self ._indent } { lvalue .name } : { self .add_typing_import ('Incomplete' )} \n " )
11721162 self ._state = VAR
11731163
11741164 def is_alias_expression (self , expr : Expression , top_level : bool = True ) -> bool :
@@ -1384,13 +1374,14 @@ def typing_name(self, name: str) -> str:
13841374 else :
13851375 return name
13861376
1387- def add_typing_import (self , name : str ) -> None :
1377+ def add_typing_import (self , name : str ) -> str :
13881378 """Add a name to be imported for typing, unless it's imported already.
13891379
13901380 The import will be internal to the stub.
13911381 """
13921382 name = self .typing_name (name )
13931383 self .import_tracker .require_name (name )
1384+ return name
13941385
13951386 def add_import_line (self , line : str ) -> None :
13961387 """Add a line of text to the import section, unless it's already there."""
@@ -1448,11 +1439,9 @@ def get_str_type_of_node(
14481439 if isinstance (rvalue , NameExpr ) and rvalue .name in ("True" , "False" ):
14491440 return "bool"
14501441 if can_infer_optional and isinstance (rvalue , NameExpr ) and rvalue .name == "None" :
1451- self .add_typing_import ("Incomplete" )
1452- return f"{ self .typing_name ('Incomplete' )} | None"
1442+ return f"{ self .add_typing_import ('Incomplete' )} | None"
14531443 if can_be_any :
1454- self .add_typing_import ("Incomplete" )
1455- return self .typing_name ("Incomplete" )
1444+ return self .add_typing_import ("Incomplete" )
14561445 else :
14571446 return ""
14581447
0 commit comments