Skip to content

Commit 984f7bf

Browse files
committed
Handle collections.abc.Callable
1 parent e449ad2 commit 984f7bf

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

strcs/hints.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import sys
1616
import types
1717
import typing
18-
from collections.abc import Mapping
18+
from collections.abc import Callable, Mapping
1919
from typing import (
2020
TYPE_CHECKING,
2121
ForwardRef,
@@ -110,6 +110,19 @@ def has(self, obj: T, origin: type[C]) -> TypeGuard["WithClassGetItem"]:
110110
)
111111

112112

113+
class IsCallable(Protocol):
114+
"""
115+
Used to identify objects that are unions
116+
"""
117+
118+
__args__: tuple
119+
__origin__: WithClassGetItem
120+
121+
@classmethod
122+
def has(self, obj: object) -> TypeGuard["IsCallable"]:
123+
return typing.get_origin(obj) in (Callable,)
124+
125+
113126
def resolve_type(
114127
typ: object,
115128
globalns: dict[str, object] | None = None,
@@ -140,13 +153,18 @@ def resolve_type(
140153
return typ
141154
return functools.reduce(operator.or_, resolved)
142155

156+
elif IsCallable.has(typ):
157+
resolved = tuple(resolve_type(t, globalns, localns) for t in typ.__args__)
158+
if len(resolved) == 0:
159+
return typ
160+
*args, ret = resolved
161+
return typ.__origin__.__class_getitem__((args, ret))
162+
143163
elif isinstance(origin, type) and WithClassGetItem.has(typ, origin):
144164
resolved = tuple(resolve_type(t, globalns, localns) for t in typ.__args__)
145165
if resolved == typ.__args__:
146166
return typ
147-
return typ.__origin__.__class_getitem__(
148-
tuple(resolve_type(t, globalns, localns) for t in typ.__args__)
149-
)
167+
return typ.__origin__.__class_getitem__(resolved)
150168

151169
else:
152170
return typ

0 commit comments

Comments
 (0)