Skip to content

Commit ba40d3f

Browse files
committed
Keep NoneType in Union TypeVar values
1 parent 8d2715a commit ba40d3f

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

mypy/expandtype.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,30 @@
5252

5353

5454
@overload
55-
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...
55+
def expand_type(
56+
typ: CallableType, env: Mapping[TypeVarId, Type], *, keep_none_type: bool = ...
57+
) -> CallableType: ...
5658

5759

5860
@overload
59-
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...
61+
def expand_type(
62+
typ: ProperType, env: Mapping[TypeVarId, Type], *, strict_optional: bool = ...
63+
) -> ProperType: ...
6064

6165

6266
@overload
63-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...
67+
def expand_type(
68+
typ: Type, env: Mapping[TypeVarId, Type], *, strict_optional: bool = ...
69+
) -> Type: ...
6470

6571

66-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
72+
def expand_type(
73+
typ: Type, env: Mapping[TypeVarId, Type], *, strict_optional: bool = False
74+
) -> Type:
6775
"""Substitute any type variable references in a type given by a type
6876
environment.
6977
"""
70-
return typ.accept(ExpandTypeVisitor(env))
78+
return typ.accept(ExpandTypeVisitor(env, strict_optional))
7179

7280

7381
@overload
@@ -184,8 +192,9 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
184192

185193
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
186194

187-
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
195+
def __init__(self, variables: Mapping[TypeVarId, Type], strict_optional: bool = False) -> None:
188196
self.variables = variables
197+
self.strict_optional = strict_optional
189198
self.recursive_guard: set[Type | tuple[int, Type]] = set()
190199

191200
def visit_unbound_type(self, t: UnboundType) -> Type:
@@ -460,7 +469,7 @@ def visit_union_type(self, t: UnionType) -> Type:
460469
# might be subtypes of others, however calling make_simplified_union()
461470
# can cause recursion, so we just remove strict duplicates.
462471
simplified = UnionType.make_union(
463-
remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
472+
remove_trivial(flatten_nested_unions(expanded), self.strict_optional), t.line, t.column
464473
)
465474
# This call to get_proper_type() is unfortunate but is required to preserve
466475
# the invariant that ProperType will stay ProperType after applying expand_type(),
@@ -508,7 +517,7 @@ def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
508517
return typ
509518

510519

511-
def remove_trivial(types: Iterable[Type]) -> list[Type]:
520+
def remove_trivial(types: Iterable[Type], strict_optional: bool = False) -> list[Type]:
512521
"""Make trivial simplifications on a list of types without calling is_subtype().
513522
514523
This makes following simplifications:
@@ -523,7 +532,7 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
523532
p_t = get_proper_type(t)
524533
if isinstance(p_t, UninhabitedType):
525534
continue
526-
if isinstance(p_t, NoneType) and not state.strict_optional:
535+
if isinstance(p_t, NoneType) and not state.strict_optional and not strict_optional:
527536
removed_none = True
528537
continue
529538
if isinstance(p_t, Instance) and p_t.type.fullname == "builtins.object":

mypy/typeanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1894,7 +1894,7 @@ def fix_instance(
18941894
t.args = tuple(args)
18951895
fix_type_var_tuple_argument(t)
18961896
if not t.type.has_type_var_tuple_type:
1897-
fixed = expand_type(t, env)
1897+
fixed = expand_type(t, env, strict_optional=True)
18981898
assert isinstance(fixed, Instance)
18991899
t.args = fixed.args
19001900

0 commit comments

Comments
 (0)