Skip to content

Commit e88c8f9

Browse files
Fix/callable class generator not handled (#108)
* fix: handle callable class instance generators * refactor: rewrite is coro check function to have the same style with other check functions * test: add tests for class callable generators to test task * docs: update path in docstring * refactor: remove unused code in tests * test: add tests for class callable generators to test execute * version: bump to new 0.77.0 version * refactor: remove obsolete sync param in test concurrency async case --------- Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
1 parent 4d94386 commit e88c8f9

File tree

4 files changed

+66
-49
lines changed

4 files changed

+66
-49
lines changed

di/_utils/inspect.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,27 @@ def unwrap_callable(call: Any) -> Any:
3030
def is_coroutine_callable(call: Any) -> bool:
3131
if inspect.isclass(call):
3232
return False
33-
call = unwrap_callable(call)
34-
if inspect.iscoroutinefunction(call):
33+
unwrapped_call = unwrap_callable(call)
34+
if inspect.iscoroutinefunction(unwrapped_call):
3535
return True
36-
# not a class but has a __call__, so maybe a callable class instance
37-
return inspect.iscoroutinefunction(getattr(call, "__call__"))
36+
dunder_call = getattr(unwrapped_call, "__call__", None)
37+
return inspect.iscoroutinefunction(dunder_call)
3838

3939

4040
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
41-
return inspect.isasyncgenfunction(unwrap_callable(call))
41+
unwrapped_call = unwrap_callable(call)
42+
if inspect.isasyncgenfunction(unwrapped_call):
43+
return True
44+
dunder_call = getattr(unwrapped_call, "__call__", None)
45+
return inspect.isasyncgenfunction(dunder_call)
4246

4347

4448
def is_gen_callable(call: Any) -> bool:
45-
return inspect.isgeneratorfunction(unwrap_callable(call))
49+
unwrapped_call = unwrap_callable(call)
50+
if inspect.isgeneratorfunction(unwrapped_call):
51+
return True
52+
dunder_call = getattr(unwrapped_call, "__call__", None)
53+
return inspect.isgeneratorfunction(dunder_call)
4654

4755

4856
def get_annotations(call: Callable[..., Any]) -> Dict[str, Any]:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "di"
3-
version = "0.76.0"
3+
version = "0.77.0"
44
description = "Dependency injection toolkit"
55
authors = ["Adrian Garcia Badaracco <adrian@adriangb.com>"]
66
readme = "README.md"

tests/test_execute.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -92,32 +92,6 @@ def test_execute():
9292
assert res.three.zero is res.zero
9393

9494

95-
def sync_callable_func() -> int:
96-
return 1
97-
98-
99-
async def async_callable_func() -> int:
100-
return 1
101-
102-
103-
def sync_gen_func() -> Generator[int, None, None]:
104-
yield 1
105-
106-
107-
async def async_gen_func() -> AsyncGenerator[int, None]:
108-
yield 1
109-
110-
111-
class SyncCallableCls:
112-
def __call__(self) -> int:
113-
return 1
114-
115-
116-
class AsyncCallableCls:
117-
async def __call__(self) -> int:
118-
return 1
119-
120-
12195
@dataclass
12296
class Synchronizer:
12397
started: List[anyio.Event]
@@ -164,15 +138,30 @@ async def __call__(self, synchronizer: Synchronizer) -> None:
164138
await async_callable_func_slow(synchronizer)
165139

166140

141+
class SyncGenCallableClsSlow:
142+
@as_async
143+
def __call__(self, synchronizer: Synchronizer) -> Generator[None, None, None]:
144+
_sync_callable_func_slow(synchronizer)
145+
yield None
146+
147+
148+
class AsyncGenCallableClsSlow:
149+
async def __call__(self, synchronizer: Synchronizer) -> AsyncGenerator[None, None]:
150+
await async_callable_func_slow(synchronizer)
151+
yield None
152+
153+
167154
@pytest.mark.parametrize(
168-
"dep1,sync1",
155+
"dep1",
169156
[
170-
(sync_callable_func_slow, True),
171-
(async_callable_func_slow, False),
172-
(sync_gen_func_slow, True),
173-
(async_gen_func_slow, False),
174-
(SyncCallableClsSlow(), True),
175-
(AsyncCallableClsSlow(), False),
157+
sync_callable_func_slow,
158+
async_callable_func_slow,
159+
sync_gen_func_slow,
160+
async_gen_func_slow,
161+
SyncCallableClsSlow(),
162+
AsyncCallableClsSlow(),
163+
SyncGenCallableClsSlow(),
164+
AsyncGenCallableClsSlow(),
176165
],
177166
ids=[
178167
"sync_callable_func",
@@ -181,17 +170,21 @@ async def __call__(self, synchronizer: Synchronizer) -> None:
181170
"async_gen_func",
182171
"SyncCallableCls",
183172
"AsyncCallableCls",
173+
"SyncGenCallableCls",
174+
"AsyncGenCallableCls",
184175
],
185176
)
186177
@pytest.mark.parametrize(
187-
"dep2,sync2",
178+
"dep2",
188179
[
189-
(sync_callable_func_slow, True),
190-
(async_callable_func_slow, False),
191-
(sync_gen_func_slow, True),
192-
(async_gen_func_slow, False),
193-
(SyncCallableClsSlow(), True),
194-
(AsyncCallableClsSlow(), False),
180+
sync_callable_func_slow,
181+
async_callable_func_slow,
182+
sync_gen_func_slow,
183+
async_gen_func_slow,
184+
SyncCallableClsSlow(),
185+
AsyncCallableClsSlow(),
186+
SyncGenCallableClsSlow(),
187+
AsyncGenCallableClsSlow(),
195188
],
196189
ids=[
197190
"sync_callable_func",
@@ -200,10 +193,12 @@ async def __call__(self, synchronizer: Synchronizer) -> None:
200193
"async_gen_func",
201194
"SyncCallableCls",
202195
"AsyncCallableCls",
196+
"SyncGenCallableCls",
197+
"AsyncGenCallableCls",
203198
],
204199
)
205200
@pytest.mark.anyio
206-
async def test_concurrency_async(dep1: Any, sync1: bool, dep2: Any, sync2: bool):
201+
async def test_concurrency_async(dep1: Any, dep2: Any):
207202
container = Container()
208203

209204
synchronizer = Synchronizer([anyio.Event(), anyio.Event()], anyio.Event())

tests/test_task.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Black box tests that check the high level API but in practice are written to fully
2-
test all of the execution paths in di/_utils/task.py
2+
test all of the execution paths in di/_task.py
33
"""
44
import functools
55
from typing import Any, AsyncGenerator, Callable, Generator
@@ -37,6 +37,16 @@ async def __call__(self) -> int:
3737
return 1
3838

3939

40+
class SyncGenCallableCls:
41+
def __call__(self) -> Generator[int, None, None]:
42+
yield 1
43+
44+
45+
class AsyncGenCallableCls:
46+
async def __call__(self) -> AsyncGenerator[int, None]:
47+
yield 1
48+
49+
4050
def no_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
4151
return func
4252

@@ -62,6 +72,8 @@ def wrapper(*args, **kwargs): # type: ignore
6272
async_gen_func,
6373
SyncCallableCls(),
6474
AsyncCallableCls(),
75+
SyncGenCallableCls(),
76+
AsyncGenCallableCls(),
6577
],
6678
ids=[
6779
"sync_callable_func",
@@ -70,6 +82,8 @@ def wrapper(*args, **kwargs): # type: ignore
7082
"async_gen_func",
7183
"SyncCallableCls",
7284
"AsyncCallableCls",
85+
"SyncGenCallableCls",
86+
"AsyncGenCallableCls",
7387
],
7488
)
7589
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)