|
15 | 15 | import sys |
16 | 16 | import types |
17 | 17 | import typing |
18 | | -from collections.abc import Mapping |
| 18 | +from collections.abc import Callable, Mapping |
19 | 19 | from typing import ( |
20 | 20 | TYPE_CHECKING, |
21 | 21 | ForwardRef, |
@@ -110,6 +110,19 @@ def has(self, obj: T, origin: type[C]) -> TypeGuard["WithClassGetItem"]: |
110 | 110 | ) |
111 | 111 |
|
112 | 112 |
|
| 113 | +class IsCallable(Protocol): |
| 114 | + """ |
| 115 | + Used to identify objects that are unions |
| 116 | + """ |
| 117 | + |
| 118 | + __args__: tuple |
| 119 | + __origin__: WithClassGetItem |
| 120 | + |
| 121 | + @classmethod |
| 122 | + def has(self, obj: object) -> TypeGuard["IsCallable"]: |
| 123 | + return typing.get_origin(obj) in (Callable,) |
| 124 | + |
| 125 | + |
113 | 126 | def resolve_type( |
114 | 127 | typ: object, |
115 | 128 | globalns: dict[str, object] | None = None, |
@@ -140,13 +153,18 @@ def resolve_type( |
140 | 153 | return typ |
141 | 154 | return functools.reduce(operator.or_, resolved) |
142 | 155 |
|
| 156 | + elif IsCallable.has(typ): |
| 157 | + resolved = tuple(resolve_type(t, globalns, localns) for t in typ.__args__) |
| 158 | + if len(resolved) == 0: |
| 159 | + return typ |
| 160 | + *args, ret = resolved |
| 161 | + return typ.__origin__.__class_getitem__((args, ret)) |
| 162 | + |
143 | 163 | elif isinstance(origin, type) and WithClassGetItem.has(typ, origin): |
144 | 164 | resolved = tuple(resolve_type(t, globalns, localns) for t in typ.__args__) |
145 | 165 | if resolved == typ.__args__: |
146 | 166 | return typ |
147 | | - return typ.__origin__.__class_getitem__( |
148 | | - tuple(resolve_type(t, globalns, localns) for t in typ.__args__) |
149 | | - ) |
| 167 | + return typ.__origin__.__class_getitem__(resolved) |
150 | 168 |
|
151 | 169 | else: |
152 | 170 | return typ |
|
0 commit comments