|
15 | 15 |
|
16 | 16 | from abc import abstractmethod |
17 | 17 | from collections.abc import Iterable, Sequence |
18 | | -from typing import Any, Callable, Final, Generic, TypeVar, cast |
| 18 | +from typing import Any, Final, Generic, TypeVar, cast |
19 | 19 |
|
20 | 20 | from mypy_extensions import mypyc_attr, trait |
21 | 21 |
|
@@ -353,16 +353,19 @@ class TypeQuery(SyntheticTypeVisitor[T]): |
353 | 353 | # TODO: check that we don't have existing violations of this rule. |
354 | 354 | """ |
355 | 355 |
|
356 | | - def __init__(self, strategy: Callable[[list[T]], T]) -> None: |
357 | | - self.strategy = strategy |
| 356 | + def __init__(self) -> None: |
358 | 357 | # Keep track of the type aliases already visited. This is needed to avoid |
359 | 358 | # infinite recursion on types like A = Union[int, List[A]]. |
360 | | - self.seen_aliases: set[TypeAliasType] = set() |
| 359 | + self.seen_aliases: set[TypeAliasType] | None = None |
361 | 360 | # By default, we eagerly expand type aliases, and query also types in the |
362 | 361 | # alias target. In most cases this is a desired behavior, but we may want |
363 | 362 | # to skip targets in some cases (e.g. when collecting type variables). |
364 | 363 | self.skip_alias_target = False |
365 | 364 |
|
| 365 | + @abstractmethod |
| 366 | + def strategy(self, items: list[T]) -> T: |
| 367 | + raise NotImplementedError |
| 368 | + |
366 | 369 | def visit_unbound_type(self, t: UnboundType, /) -> T: |
367 | 370 | return self.query_types(t.args) |
368 | 371 |
|
@@ -440,14 +443,15 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> T: |
440 | 443 | return self.query_types(t.args) |
441 | 444 |
|
442 | 445 | def visit_type_alias_type(self, t: TypeAliasType, /) -> T: |
443 | | - # Skip type aliases already visited types to avoid infinite recursion. |
444 | | - # TODO: Ideally we should fire subvisitors here (or use caching) if we care |
445 | | - # about duplicates. |
446 | | - if t in self.seen_aliases: |
447 | | - return self.strategy([]) |
448 | | - self.seen_aliases.add(t) |
449 | 446 | if self.skip_alias_target: |
450 | 447 | return self.query_types(t.args) |
| 448 | + # Skip type aliases already visited types to avoid infinite recursion. |
| 449 | + if t.is_recursive: |
| 450 | + if self.seen_aliases is None: |
| 451 | + self.seen_aliases = set() |
| 452 | + elif t in self.seen_aliases: |
| 453 | + return self.strategy([]) |
| 454 | + self.seen_aliases.add(t) |
451 | 455 | return get_proper_type(t).accept(self) |
452 | 456 |
|
453 | 457 | def query_types(self, types: Iterable[Type]) -> T: |
@@ -580,16 +584,15 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> bool: |
580 | 584 | return self.query_types(t.args) |
581 | 585 |
|
582 | 586 | def visit_type_alias_type(self, t: TypeAliasType, /) -> bool: |
583 | | - # Skip type aliases already visited types to avoid infinite recursion. |
584 | | - # TODO: Ideally we should fire subvisitors here (or use caching) if we care |
585 | | - # about duplicates. |
586 | | - if self.seen_aliases is None: |
587 | | - self.seen_aliases = set() |
588 | | - elif t in self.seen_aliases: |
589 | | - return self.default |
590 | | - self.seen_aliases.add(t) |
591 | 587 | if self.skip_alias_target: |
592 | 588 | return self.query_types(t.args) |
| 589 | + # Skip type aliases already visited types to avoid infinite recursion. |
| 590 | + if t.is_recursive: |
| 591 | + if self.seen_aliases is None: |
| 592 | + self.seen_aliases = set() |
| 593 | + elif t in self.seen_aliases: |
| 594 | + return self.default |
| 595 | + self.seen_aliases.add(t) |
593 | 596 | return get_proper_type(t).accept(self) |
594 | 597 |
|
595 | 598 | def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool: |
|
0 commit comments