3 | 3 | from __future__ import annotations
4 | 4 |
5 | 5 | from collections import defaultdict |
6 | | -from typing import Final, NamedTuple
| 6 | +from itertools import zip_longest |
| 7 | +from typing import Final, Literal, NamedTuple |
7 | 8 |
8 | 9 | from mypy import message_registry |
9 | 10 | from mypy.checker_shared import TypeCheckerSharedApi, TypeRange |
@@ -245,34 +246,79 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
245 | 246 | # |
246 | 247 | # get inner types of original type |
247 | 248 | # |
248 | | - unpack_index = None |
249 | | - if isinstance(current_type, TupleType): |
250 | | - inner_types = current_type.items |
251 | | - unpack_index = find_unpack_in_list(inner_types) |
252 | | - if unpack_index is None: |
253 | | - size_diff = len(inner_types) - required_patterns |
254 | | - if size_diff < 0: |
255 | | - return self.early_non_match() |
256 | | - elif size_diff > 0 and star_position is None: |
257 | | - return self.early_non_match() |
| 249 | + # 1. Go through all possible types and keep only the sequence types that could match the number of pattern items
| 250 | + # 2. If there are multiple tuples with unpacks left, use the fallback logic where we union all item types
| 251 | + # 3. Otherwise, combine the item types index-wise so that each index can have a unique type
| 252 | + |
| 253 | + # State of the matching loop over the candidate types:
| 254 | + state: (
| 255 | + # NO_UNPACK: no unpack encountered yet; store one list of item types per candidate that could match
| 256 | + # the sequence (for tuples one type per index, otherwise the item type repeated).
| 257 | + tuple[Literal["NO_UNPACK"], list[list[Type]]] |
| 258 | + # UNPACK: a single tuple with an unpack; store the tuple, the unpack index, and its index in the union.
| 259 | + tuple[Literal["UNPACK"], TupleType, int, int] |
| 260 | + # MULTI_UNPACK: a tuple with an unpack was seen alongside other candidates; store a list of item-type lists.
| 261 | + # Tuples without unpacks keep one entry per index; everything else stores its item type repeated.
| 262 | + tuple[Literal["MULTI_UNPACK"], list[list[Type]]]
| 263 | + ) = ("NO_UNPACK", [])
| 264 | + for i, t in enumerate(current_type.items) if isinstance(current_type, UnionType) else ((0, current_type),): |
| 265 | + t = get_proper_type(t) |
| 266 | + n_patterns = len(o.patterns) |
| 267 | + if isinstance(t, TupleType): |
| 268 | + t_items = t.items |
| 269 | + unpack_index = find_unpack_in_list(t_items) |
| 270 | + if unpack_index is None: |
| 271 | + size_diff = len(t_items) - required_patterns |
| 272 | + if size_diff < 0: |
| 273 | + continue |
| 274 | + elif size_diff > 0 and star_position is None: |
| 275 | + continue |
| 276 | + inner_t = list(t_items) |
| 277 | + else: |
| 278 | + normalized_inner_types = [] |
| 279 | + for it in t_items: |
| 280 | + # Unfortunately, it is not possible to "split" the TypeVarTuple |
| 281 | + # into individual items, so we just use its upper bound for the whole |
| 282 | + # analysis instead. |
| 283 | + if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): |
| 284 | + it = UnpackType(it.type.upper_bound) |
| 285 | + normalized_inner_types.append(it) |
| 286 | + t_items = normalized_inner_types |
| 287 | + t = t.copy_modified(items=normalized_inner_types) |
| 288 | + if len(t_items) - 1 > required_patterns and star_position is None: |
| 289 | + continue |
| 290 | + if state[0] == "NO_UNPACK" and not state[1]: |
| 291 | + state = ("UNPACK", t, unpack_index, i) |
| 292 | + continue |
| 293 | + inner_t = [self.chk.iterable_item_type(tuple_fallback(t), o)] * n_patterns |
| 294 | + elif isinstance(t, AnyType): |
| 295 | + inner_t = [AnyType(TypeOfAny.from_another_any, t)] * n_patterns |
| 296 | + elif self.chk.type_is_iterable(t) and isinstance(t, Instance): |
| 297 | + inner_t = [self.chk.iterable_item_type(t, o)] * n_patterns |
258 | 298 | else: |
259 | | - normalized_inner_types = [] |
260 | | - for it in inner_types: |
261 | | - # Unfortunately, it is not possible to "split" the TypeVarTuple |
262 | | - # into individual items, so we just use its upper bound for the whole |
263 | | - # analysis instead. |
264 | | - if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): |
265 | | - it = UnpackType(it.type.upper_bound) |
266 | | - normalized_inner_types.append(it) |
267 | | - inner_types = normalized_inner_types |
268 | | - current_type = current_type.copy_modified(items=normalized_inner_types) |
269 | | - if len(inner_types) - 1 > required_patterns and star_position is None: |
270 | | - return self.early_non_match() |
| 299 | + continue |
| 300 | + # If we previously encountered a single tuple with an unpack, change the state:
| 301 | + if state[0] == "UNPACK":
| 302 | + # demote the stored tuple to its fallback item type so it can be merged with the other candidates.
| 303 | + state = ("MULTI_UNPACK", [[self.chk.iterable_item_type(tuple_fallback(state[1]), o)] * n_patterns])
| 304 | + assert state[0] != "UNPACK" # for type checker |
| 305 | + state[1].append(inner_t) |
| 306 | + |
| 307 | + if state[0] == "UNPACK": |
| 308 | + _, update_tuple_type, unpack_index, union_index = state |
| 309 | + inner_types = update_tuple_type.items |
| 310 | + if isinstance(current_type, UnionType): |
| 311 | + union_items = list(current_type.items) |
| 312 | + union_items[union_index] = update_tuple_type |
| 313 | + current_type = current_type.copy_modified(items=union_items) |
| 314 | + else: |
| 315 | + assert union_index == 0, "Union index should be 0 for non-union types"
| 316 | + current_type = update_tuple_type |
271 | 317 | else: |
272 | | - inner_type = self.get_sequence_type(current_type, o) |
273 | | - if inner_type is None: |
274 | | - inner_type = self.chk.named_type("builtins.object") |
275 | | - inner_types = [inner_type] * len(o.patterns) |
| 318 | + unpack_index = None |
| 319 | + if not state[1]: |
| 320 | + return self.early_non_match() |
| 321 | + inner_types = [make_simplified_union(x) for x in zip_longest(*state[1], fillvalue=UninhabitedType())] |
276 | 322 |
277 | 323 | # |
278 | 324 | # match inner patterns |
@@ -351,25 +397,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
351 | 397 | new_type = self.narrow_sequence_child(current_type, new_inner_type, o) |
352 | 398 | return PatternType(new_type, rest_type, captures) |
353 | 399 |
354 | | - def get_sequence_type(self, t: Type, context: Context) -> Type | None: |
355 | | - t = get_proper_type(t) |
356 | | - if isinstance(t, AnyType): |
357 | | - return AnyType(TypeOfAny.from_another_any, t) |
358 | | - if isinstance(t, UnionType): |
359 | | - items = [self.get_sequence_type(item, context) for item in t.items] |
360 | | - not_none_items = [item for item in items if item is not None] |
361 | | - if not_none_items: |
362 | | - return make_simplified_union(not_none_items) |
363 | | - else: |
364 | | - return None |
365 | | - |
366 | | - if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)): |
367 | | - if isinstance(t, TupleType): |
368 | | - t = tuple_fallback(t) |
369 | | - return self.chk.iterable_item_type(t, context) |
370 | | - else: |
371 | | - return None |
372 | | - |
373 | 400 | def contract_starred_pattern_types( |
374 | 401 | self, types: list[Type], star_pos: int | None, num_patterns: int |
375 | 402 | ) -> list[Type]: |
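
A quick illustration of the merge performed at the end of the `NO_UNPACK` / `MULTI_UNPACK` path (`zip_longest` + `make_simplified_union`): every surviving candidate contributes a list of per-index item types, the lists are transposed, and each column is collapsed into one union. The sketch below is a toy stand-in, not mypy internals: plain strings replace `Type` objects, `NEVER` replaces `UninhabitedType()`, and the dedup loop replaces `make_simplified_union`.

```python
from itertools import zip_longest

# Stand-in for UninhabitedType ("Never"): positions a candidate cannot provide.
NEVER = "Never"


def merge_per_index(candidates: list[list[str]]) -> list[str]:
    """Union the candidates' item types at each index (toy make_simplified_union)."""
    merged = []
    for column in zip_longest(*candidates, fillvalue=NEVER):
        # Drop the padding and deduplicate while preserving order.
        seen = dict.fromkeys(t for t in column if t != NEVER)
        merged.append(" | ".join(seen) if seen else NEVER)
    return merged


# Two union members that both fit a starred pattern,
# e.g. tuple[int, str] | tuple[bytes, str, float]:
print(merge_per_index([["int", "str"], ["bytes", "str", "float"]]))
# ['int | bytes', 'str', 'float']
```

In the real code the `UninhabitedType()` padding should be absorbed by `make_simplified_union`, so a shorter candidate simply does not constrain the extra positions.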
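And a user-level sketch of the narrowing that step 3 is aiming at; the revealed types are my expectation of the per-index result, not output taken from this PR's tests:

```python
from typing import reveal_type  # runtime reveal_type needs Python 3.11+


def f(x: tuple[int, str] | tuple[bytes, float]) -> None:
    match x:
        case [a, b]:
            # Each position keeps its own union instead of collapsing to a
            # single joined element type shared by every index.
            reveal_type(a)  # expected: int | bytes
            reveal_type(b)  # expected: str | float
```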