|
15 | 15 |
|
16 | 16 | from abc import abstractmethod
|
17 | 17 | from collections.abc import Iterable, Sequence
|
18 |
| -from typing import Any, Callable, Final, Generic, TypeVar, cast |
| 18 | +from typing import Any, Final, Generic, TypeVar, cast |
19 | 19 |
|
20 | 20 | from mypy_extensions import mypyc_attr, trait
|
21 | 21 |
|
@@ -353,16 +353,19 @@ class TypeQuery(SyntheticTypeVisitor[T]):
|
353 | 353 | # TODO: check that we don't have existing violations of this rule.
|
354 | 354 | """
|
355 | 355 |
|
356 |
| - def __init__(self, strategy: Callable[[list[T]], T]) -> None: |
357 |
| - self.strategy = strategy |
| 356 | + def __init__(self) -> None: |
358 | 357 | # Keep track of the type aliases already visited. This is needed to avoid
|
359 | 358 | # infinite recursion on types like A = Union[int, List[A]].
|
360 |
| - self.seen_aliases: set[TypeAliasType] = set() |
| 359 | + self.seen_aliases: set[TypeAliasType] | None = None |
361 | 360 | # By default, we eagerly expand type aliases, and query also types in the
|
362 | 361 | # alias target. In most cases this is a desired behavior, but we may want
|
363 | 362 | # to skip targets in some cases (e.g. when collecting type variables).
|
364 | 363 | self.skip_alias_target = False
|
365 | 364 |
|
| 365 | + @abstractmethod |
| 366 | + def strategy(self, items: list[T]) -> T: |
| 367 | + raise NotImplementedError |
| 368 | + |
366 | 369 | def visit_unbound_type(self, t: UnboundType, /) -> T:
|
367 | 370 | return self.query_types(t.args)
|
368 | 371 |
|
@@ -440,14 +443,15 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> T:
|
440 | 443 | return self.query_types(t.args)
|
441 | 444 |
|
442 | 445 | def visit_type_alias_type(self, t: TypeAliasType, /) -> T:
|
443 |
| - # Skip type aliases already visited types to avoid infinite recursion. |
444 |
| - # TODO: Ideally we should fire subvisitors here (or use caching) if we care |
445 |
| - # about duplicates. |
446 |
| - if t in self.seen_aliases: |
447 |
| - return self.strategy([]) |
448 |
| - self.seen_aliases.add(t) |
449 | 446 | if self.skip_alias_target:
|
450 | 447 | return self.query_types(t.args)
|
| 448 | + # Skip type aliases already visited types to avoid infinite recursion |
| 449 | + # (also use this as a simple-minded cache). |
| 450 | + if self.seen_aliases is None: |
| 451 | + self.seen_aliases = set() |
| 452 | + elif t in self.seen_aliases: |
| 453 | + return self.strategy([]) |
| 454 | + self.seen_aliases.add(t) |
451 | 455 | return get_proper_type(t).accept(self)
|
452 | 456 |
|
453 | 457 | def query_types(self, types: Iterable[Type]) -> T:
|
@@ -580,16 +584,15 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> bool:
|
580 | 584 | return self.query_types(t.args)
|
581 | 585 |
|
582 | 586 | def visit_type_alias_type(self, t: TypeAliasType, /) -> bool:
|
583 |
| - # Skip type aliases already visited types to avoid infinite recursion. |
584 |
| - # TODO: Ideally we should fire subvisitors here (or use caching) if we care |
585 |
| - # about duplicates. |
| 587 | + if self.skip_alias_target: |
| 588 | + return self.query_types(t.args) |
| 589 | + # Skip type aliases already visited types to avoid infinite recursion |
| 590 | + # (also use this as a simple-minded cache). |
586 | 591 | if self.seen_aliases is None:
|
587 | 592 | self.seen_aliases = set()
|
588 | 593 | elif t in self.seen_aliases:
|
589 | 594 | return self.default
|
590 | 595 | self.seen_aliases.add(t)
|
591 |
| - if self.skip_alias_target: |
592 |
| - return self.query_types(t.args) |
593 | 596 | return get_proper_type(t).accept(self)
|
594 | 597 |
|
595 | 598 | def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool:
|
|
0 commit comments