Skip to content

Commit 6f768ae

Browse files
committed
Support nested variants, add a test
1 parent 6c55624 commit 6f768ae

File tree

2 files changed

+664
-32
lines changed

2 files changed

+664
-32
lines changed

mypyc/irbuild/builder.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
AnyType,
4848
DeletedType,
4949
Instance,
50-
ParamSpecType,
5150
ProperType,
5251
TupleType,
5352
Type,
@@ -959,44 +958,44 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:
959958
960959
This is useful for dict subclasses like SymbolTable.
961960
"""
962-
target_type = get_proper_type(self.types[expr])
961+
return self.get_dict_base_type_from_type(self.types[expr])
962+
963+
def get_dict_base_type_from_type(self, target_type: Type) -> list[Instance]:
964+
target_type = get_proper_type(target_type)
963965
if isinstance(target_type, UnionType):
964-
types = [get_proper_type(item) for item in target_type.items]
966+
return [
967+
inner
968+
for item in target_type.items
969+
for inner in self.get_dict_base_type_from_type(item)
970+
]
971+
if isinstance(target_type, TypeVarLikeType):
972+
# Match behaviour of self.node_type
973+
# We can only reach this point if `target_type` was a TypeVar(bound=dict[...])
974+
# or a ParamSpec.
975+
return self.get_dict_base_type_from_type(target_type.upper_bound)
976+
977+
if isinstance(target_type, TypedDictType):
978+
target_type = target_type.fallback
979+
dict_base = next(
980+
base for base in target_type.type.mro if base.fullname == "typing.Mapping"
981+
)
982+
elif isinstance(target_type, Instance):
983+
dict_base = next(
984+
base for base in target_type.type.mro if base.fullname == "builtins.dict"
985+
)
965986
else:
966-
types = [target_type]
967-
968-
dict_types = []
969-
for t in types:
970-
if isinstance(t, TypedDictType):
971-
t = t.fallback
972-
dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping")
973-
else:
974-
if isinstance(t, ParamSpecType):
975-
# Since `ParamSpec(upper_bound=...)` is not defined yet, we know that
976-
# `bound` is set to `dict[str, object]`. In any future sane implementation
977-
# it still has to be exactly a `dict` as that's how kwargs work
978-
# at runtime.
979-
t = get_proper_type(t.upper_bound)
980-
assert isinstance(t, Instance), t
981-
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
982-
dict_types.append(map_instance_to_supertype(t, dict_base))
983-
return dict_types
987+
assert False, f"Failed to extract dict base from {target_type}"
988+
return [map_instance_to_supertype(target_type, dict_base)]
984989

985990
def get_dict_key_type(self, expr: Expression) -> RType:
986991
dict_base_types = self.get_dict_base_type(expr)
987-
if len(dict_base_types) == 1:
988-
return self.type_to_rtype(dict_base_types[0].args[0])
989-
else:
990-
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
991-
return RUnion.make_simplified_union(rtypes)
992+
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
993+
return RUnion.make_simplified_union(rtypes)
992994

993995
def get_dict_value_type(self, expr: Expression) -> RType:
994996
dict_base_types = self.get_dict_base_type(expr)
995-
if len(dict_base_types) == 1:
996-
return self.type_to_rtype(dict_base_types[0].args[1])
997-
else:
998-
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
999-
return RUnion.make_simplified_union(rtypes)
997+
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
998+
return RUnion.make_simplified_union(rtypes)
1000999

10011000
def get_dict_item_type(self, expr: Expression) -> RType:
10021001
key_type = self.get_dict_key_type(expr)

0 commit comments

Comments
 (0)