|
50 | 50 |
|
51 | 51 | @overload
|
52 | 52 | def expand_type(
|
53 |
| - typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... |
| 53 | + typ: ProperType, |
| 54 | + env: Mapping[TypeVarId, Type], |
| 55 | + allow_erased_callables: bool = ..., |
| 56 | + *, |
| 57 | + keep_none_type: bool = ..., |
54 | 58 | ) -> ProperType:
|
55 | 59 | ...
|
56 | 60 |
|
57 | 61 |
|
58 | 62 | @overload
|
59 | 63 | def expand_type(
|
60 |
| - typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... |
| 64 | + typ: Type, |
| 65 | + env: Mapping[TypeVarId, Type], |
| 66 | + allow_erased_callables: bool = ..., |
| 67 | + *, |
| 68 | + keep_none_type: bool = ..., |
61 | 69 | ) -> Type:
|
62 | 70 | ...
|
63 | 71 |
|
64 | 72 |
|
65 | 73 | def expand_type(
|
66 |
| - typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False |
| 74 | + typ: Type, |
| 75 | + env: Mapping[TypeVarId, Type], |
| 76 | + allow_erased_callables: bool = False, |
| 77 | + *, |
| 78 | + keep_none_type: bool = False, |
67 | 79 | ) -> Type:
|
68 | 80 | """Substitute any type variable references in a type given by a type
|
69 | 81 | environment.
|
70 | 82 | """
|
71 |
| - return typ.accept(ExpandTypeVisitor(env, allow_erased_callables)) |
| 83 | + return typ.accept(ExpandTypeVisitor(env, allow_erased_callables, keep_none_type)) |
72 | 84 |
|
73 | 85 |
|
74 | 86 | @overload
|
@@ -183,10 +195,14 @@ class ExpandTypeVisitor(TypeVisitor[Type]):
|
183 | 195 | variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
|
184 | 196 |
|
185 | 197 | def __init__(
|
186 |
| - self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False |
| 198 | + self, |
| 199 | + variables: Mapping[TypeVarId, Type], |
| 200 | + allow_erased_callables: bool = False, |
| 201 | + keep_none_type: bool = False, |
187 | 202 | ) -> None:
|
188 | 203 | self.variables = variables
|
189 | 204 | self.allow_erased_callables = allow_erased_callables
|
| 205 | + self.keep_none_type = keep_none_type |
190 | 206 | self.recursive_guard: set[Type | tuple[int, Type]] = set()
|
191 | 207 |
|
192 | 208 | def visit_unbound_type(self, t: UnboundType) -> Type:
|
@@ -470,7 +486,7 @@ def visit_union_type(self, t: UnionType) -> Type:
|
470 | 486 | # might be subtypes of others, however calling make_simplified_union()
|
471 | 487 | # can cause recursion, so we just remove strict duplicates.
|
472 | 488 | return UnionType.make_union(
|
473 |
| - remove_trivial(flatten_nested_unions(expanded)), t.line, t.column |
| 489 | + remove_trivial(flatten_nested_unions(expanded), self.keep_none_type), t.line, t.column |
474 | 490 | )
|
475 | 491 |
|
476 | 492 | def visit_partial_type(self, t: PartialType) -> Type:
|
|
0 commit comments