52
52
53
53
54
54
@overload
55
- def expand_type (typ : CallableType , env : Mapping [TypeVarId , Type ]) -> CallableType : ...
55
+ def expand_type (
56
+ typ : CallableType , env : Mapping [TypeVarId , Type ], * , keep_none_type : bool = ...
57
+ ) -> CallableType : ...
56
58
57
59
58
60
@overload
59
- def expand_type (typ : ProperType , env : Mapping [TypeVarId , Type ]) -> ProperType : ...
61
+ def expand_type (
62
+ typ : ProperType , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = ...
63
+ ) -> ProperType : ...
60
64
61
65
62
66
@overload
63
- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type : ...
67
+ def expand_type (
68
+ typ : Type , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = ...
69
+ ) -> Type : ...
64
70
65
71
66
- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type :
72
+ def expand_type (
73
+ typ : Type , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = False
74
+ ) -> Type :
67
75
"""Substitute any type variable references in a type given by a type
68
76
environment.
69
77
"""
70
- return typ .accept (ExpandTypeVisitor (env ))
78
+ return typ .accept (ExpandTypeVisitor (env , strict_optional ))
71
79
72
80
73
81
@overload
@@ -184,8 +192,9 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
184
192
185
193
variables : Mapping [TypeVarId , Type ] # TypeVar id -> TypeVar value
186
194
187
- def __init__ (self , variables : Mapping [TypeVarId , Type ]) -> None :
195
+ def __init__ (self , variables : Mapping [TypeVarId , Type ], strict_optional : bool = False ) -> None :
188
196
self .variables = variables
197
+ self .strict_optional = strict_optional
189
198
self .recursive_guard : set [Type | tuple [int , Type ]] = set ()
190
199
191
200
def visit_unbound_type (self , t : UnboundType ) -> Type :
@@ -460,7 +469,7 @@ def visit_union_type(self, t: UnionType) -> Type:
460
469
# might be subtypes of others, however calling make_simplified_union()
461
470
# can cause recursion, so we just remove strict duplicates.
462
471
simplified = UnionType .make_union (
463
- remove_trivial (flatten_nested_unions (expanded )), t .line , t .column
472
+ remove_trivial (flatten_nested_unions (expanded ), self . strict_optional ), t .line , t .column
464
473
)
465
474
# This call to get_proper_type() is unfortunate but is required to preserve
466
475
# the invariant that ProperType will stay ProperType after applying expand_type(),
@@ -508,7 +517,7 @@ def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
508
517
return typ
509
518
510
519
511
- def remove_trivial (types : Iterable [Type ]) -> list [Type ]:
520
+ def remove_trivial (types : Iterable [Type ], strict_optional : bool = False ) -> list [Type ]:
512
521
"""Make trivial simplifications on a list of types without calling is_subtype().
513
522
514
523
This makes following simplifications:
@@ -523,7 +532,7 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
523
532
p_t = get_proper_type (t )
524
533
if isinstance (p_t , UninhabitedType ):
525
534
continue
526
- if isinstance (p_t , NoneType ) and not state .strict_optional :
535
+ if isinstance (p_t , NoneType ) and not state .strict_optional and not strict_optional :
527
536
removed_none = True
528
537
continue
529
538
if isinstance (p_t , Instance ) and p_t .type .fullname == "builtins.object" :
0 commit comments