|
47 | 47 | AnyType, |
48 | 48 | DeletedType, |
49 | 49 | Instance, |
50 | | - ParamSpecType, |
51 | 50 | ProperType, |
52 | 51 | TupleType, |
53 | 52 | Type, |
@@ -959,44 +958,44 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]: |
959 | 958 |
|
960 | 959 | This is useful for dict subclasses like SymbolTable. |
961 | 960 | """ |
962 | | - target_type = get_proper_type(self.types[expr]) |
| 961 | + return self.get_dict_base_type_from_type(self.types[expr]) |
| 962 | + |
| 963 | + def get_dict_base_type_from_type(self, target_type: Type) -> list[Instance]: |
| 964 | + target_type = get_proper_type(target_type) |
963 | 965 | if isinstance(target_type, UnionType): |
964 | | - types = [get_proper_type(item) for item in target_type.items] |
| 966 | + return [ |
| 967 | + inner |
| 968 | + for item in target_type.items |
| 969 | + for inner in self.get_dict_base_type_from_type(item) |
| 970 | + ] |
| 971 | + if isinstance(target_type, TypeVarLikeType): |
| 972 | + # Match behaviour of self.node_type |
| 973 | + # We can only reach this point if `target_type` was a TypeVar(bound=dict[...]) |
| 974 | + # or a ParamSpec. |
| 975 | + return self.get_dict_base_type_from_type(target_type.upper_bound) |
| 976 | + |
| 977 | + if isinstance(target_type, TypedDictType): |
| 978 | + target_type = target_type.fallback |
| 979 | + dict_base = next( |
| 980 | + base for base in target_type.type.mro if base.fullname == "typing.Mapping" |
| 981 | + ) |
| 982 | + elif isinstance(target_type, Instance): |
| 983 | + dict_base = next( |
| 984 | + base for base in target_type.type.mro if base.fullname == "builtins.dict" |
| 985 | + ) |
965 | 986 | else: |
966 | | - types = [target_type] |
967 | | - |
968 | | - dict_types = [] |
969 | | - for t in types: |
970 | | - if isinstance(t, TypedDictType): |
971 | | - t = t.fallback |
972 | | - dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping") |
973 | | - else: |
974 | | - if isinstance(t, ParamSpecType): |
975 | | - # Since `ParamSpec(upper_bound=...)` is not defined yet, we know that |
976 | | - # `bound` is set to `dict[str, object]`. In any future sane implementation |
977 | | - # it still has to be exactly a `dict` as that's how kwargs work |
978 | | - # at runtime. |
979 | | - t = get_proper_type(t.upper_bound) |
980 | | - assert isinstance(t, Instance), t |
981 | | - dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict") |
982 | | - dict_types.append(map_instance_to_supertype(t, dict_base)) |
983 | | - return dict_types |
| 987 | + assert False, f"Failed to extract dict base from {target_type}" |
| 988 | + return [map_instance_to_supertype(target_type, dict_base)] |
984 | 989 |
|
985 | 990 | def get_dict_key_type(self, expr: Expression) -> RType: |
986 | 991 | dict_base_types = self.get_dict_base_type(expr) |
987 | | - if len(dict_base_types) == 1: |
988 | | - return self.type_to_rtype(dict_base_types[0].args[0]) |
989 | | - else: |
990 | | - rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types] |
991 | | - return RUnion.make_simplified_union(rtypes) |
| 992 | + rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types] |
| 993 | + return RUnion.make_simplified_union(rtypes) |
992 | 994 |
|
993 | 995 | def get_dict_value_type(self, expr: Expression) -> RType: |
994 | 996 | dict_base_types = self.get_dict_base_type(expr) |
995 | | - if len(dict_base_types) == 1: |
996 | | - return self.type_to_rtype(dict_base_types[0].args[1]) |
997 | | - else: |
998 | | - rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types] |
999 | | - return RUnion.make_simplified_union(rtypes) |
| 997 | + rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types] |
| 998 | + return RUnion.make_simplified_union(rtypes) |
1000 | 999 |
|
1001 | 1000 | def get_dict_item_type(self, expr: Expression) -> RType: |
1002 | 1001 | key_type = self.get_dict_key_type(expr) |
|
0 commit comments