Skip to content

Commit d355857

Browse files
feature: add ClassTypeHint.recurse_bases()
1 parent f550ae8 commit d355857

File tree

3 files changed

+91
-3
lines changed

3 files changed

+91
-3
lines changed

.changelog/_unreleased.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,9 @@ id = "a5578e7f-32ca-442e-b52a-e6a16c573c35"
33
type = "fix"
44
description = "`ClassTypeHint.bases` now falls back to the types actual base classes if it is not a generic type"
55
author = "@NiklasRosenstein"
6+
7+
[[entries]]
8+
id = "5883174c-0b7b-4be0-b28e-b7c69922f551"
9+
type = "feature"
10+
description = "add `ClassTypeHint.recurse_bases()`"
11+
author = "@NiklasRosenstein"

src/typeapi/typehint.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
import abc
22
import sys
3-
from collections import ChainMap
3+
from collections import ChainMap, deque
44
from types import ModuleType
5-
from typing import Any, Dict, Generic, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar, Union, cast, overload
5+
from typing import (
6+
Any,
7+
Dict,
8+
Generator,
9+
Generic,
10+
Iterator,
11+
List,
12+
Mapping,
13+
MutableMapping,
14+
Tuple,
15+
TypeVar,
16+
Union,
17+
cast,
18+
overload,
19+
)
620

7-
from typing_extensions import Annotated
21+
from typing_extensions import Annotated, Literal
822

923
from .utils import (
1024
ForwardRef,
@@ -283,6 +297,38 @@ def get_parameter_map(self) -> Dict[Any, Any]:
283297
# use self.parameters.
284298
return dict(zip(TypeHint(self.type).parameters, self.args))
285299

300+
def recurse_bases(
301+
self, order: Literal["dfs", "bfs"] = "bfs"
302+
) -> Generator["ClassTypeHint", Union[Literal["skip"], None], None]:
303+
"""
304+
Iterate over all base classes of this type hint, and continues recursively. The iteration order is
305+
determined by the *order* parameter, which can be either depth-first or breadh-first. If the generator
306+
receives the string `"skip"` from the caller, it will skip the bases of the last yielded type.
307+
"""
308+
309+
# Find the item type in the base classes of the collection type.
310+
bases = deque([self])
311+
312+
while bases:
313+
current = bases.popleft()
314+
if not isinstance(current, ClassTypeHint):
315+
raise RuntimeError(
316+
f"Expected to find a ClassTypeHint in the base classes of {self!r}, found {current!r} instead."
317+
)
318+
319+
response = yield current
320+
if response == "skip":
321+
continue
322+
323+
current_bases = cast(List[ClassTypeHint], [TypeHint(x, current.type).evaluate() for x in current.bases])
324+
325+
if order == "bfs":
326+
bases.extend(current_bases)
327+
elif order == "dfs":
328+
bases.extendleft(reversed(current_bases))
329+
else:
330+
raise ValueError(f"Invalid order {order!r}")
331+
286332

287333
class UnionTypeHint(TypeHint):
288334
def has_none_type(self) -> bool:

src/typeapi/typehint_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,39 @@ def test__TypeHint__repeated() -> None:
497497
assert hint.args == (int,)
498498
assert hint.parameters == ()
499499
assert hint.repeated
500+
501+
502+
def test__ClassTypeHint__iter_all_bases() -> None:
503+
class A:
504+
pass
505+
506+
class B(A, int):
507+
pass
508+
509+
class C(B, Generic[T]):
510+
pass
511+
512+
hint = TypeHint(C)
513+
assert isinstance(hint, ClassTypeHint)
514+
assert hint.type is C
515+
assert hint.bases == (B, Generic[T])
516+
assert list(hint.recurse_bases("bfs")) == [
517+
TypeHint(C),
518+
TypeHint(B),
519+
TypeHint(Generic[T]),
520+
TypeHint(A),
521+
TypeHint(int),
522+
TypeHint(object),
523+
TypeHint(object),
524+
TypeHint(object),
525+
]
526+
assert list(hint.recurse_bases("dfs")) == [
527+
TypeHint(C),
528+
TypeHint(B),
529+
TypeHint(A),
530+
TypeHint(object),
531+
TypeHint(int),
532+
TypeHint(object),
533+
TypeHint(Generic[T]),
534+
TypeHint(object),
535+
]

0 commit comments

Comments
 (0)