Skip to content

Commit c22baa5

Browse files
committed
Support custom async iterables
Add is_async_iterable parameter for custom iterables (#258)
1 parent 37b3d06 commit c22baa5

File tree

7 files changed

+193
-12
lines changed

7 files changed

+193
-12
lines changed

docs/modules/pyutils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ PyUtils
1616
.. autofunction:: identity_func
1717
.. autofunction:: inspect
1818
.. autofunction:: is_awaitable
19+
.. autofunction:: is_async_iterable
1920
.. autofunction:: is_collection
2021
.. autofunction:: is_iterable
2122
.. autofunction:: natural_comparison_key

src/graphql/execution/execute.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@
4747
inspect,
4848
is_iterable,
4949
)
50-
from ..pyutils.is_awaitable import is_awaitable as default_is_awaitable
50+
from ..pyutils.is_awaitable import (
51+
is_async_iterable as default_is_async_iterable,
52+
)
53+
from ..pyutils.is_awaitable import (
54+
is_awaitable as default_is_awaitable,
55+
)
5156
from ..type import (
5257
GraphQLAbstractType,
5358
GraphQLField,
@@ -201,6 +206,9 @@ class ExecutionContext(IncrementalPublisherContext):
201206
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] = staticmethod(
202207
default_is_awaitable # type: ignore
203208
)
209+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] = staticmethod(
210+
default_is_async_iterable # type: ignore
211+
)
204212

205213
def __init__(
206214
self,
@@ -216,6 +224,7 @@ def __init__(
216224
enable_early_execution: bool = False,
217225
middleware_manager: MiddlewareManager | None = None,
218226
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
227+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
219228
) -> None:
220229
self.schema = schema
221230
self.fragments = fragments
@@ -229,6 +238,7 @@ def __init__(
229238
self.enable_early_execution = enable_early_execution
230239
self.middleware_manager = middleware_manager
231240
self.is_awaitable = is_awaitable or default_is_awaitable
241+
self.is_async_iterable = is_async_iterable or default_is_async_iterable
232242
self.errors = None
233243
self.cancellable_streams = None
234244
self._canceled_iterators: set[AsyncIterator] = set()
@@ -251,6 +261,7 @@ def build(
251261
enable_early_execution: bool = False,
252262
middleware: Middleware | None = None,
253263
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
264+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
254265
**custom_args: Any,
255266
) -> list[GraphQLError] | ExecutionContext:
256267
"""Build an execution context
@@ -325,6 +336,7 @@ def build(
325336
enable_early_execution,
326337
middleware_manager,
327338
is_awaitable,
339+
is_async_iterable,
328340
**custom_args,
329341
)
330342

@@ -1041,7 +1053,7 @@ def complete_list_value(
10411053
"""
10421054
item_type = return_type.of_type
10431055

1044-
if isinstance(result, AsyncIterable):
1056+
if self.is_async_iterable(result):
10451057
async_iterator = result.__aiter__()
10461058

10471059
return self.complete_async_iterator_value(
@@ -1582,8 +1594,8 @@ def map_source_to_response(
15821594
as it is nearly identical to the "ExecuteQuery" algorithm,
15831595
for which :func:`~graphql.execution.execute` is also used.
15841596
"""
1585-
if not isinstance(result_or_stream, AsyncIterable):
1586-
return result_or_stream # pragma: no cover
1597+
if not self.is_async_iterable(result_or_stream):
1598+
return cast("ExecutionResult", result_or_stream) # pragma: no cover
15871599

15881600
build_context = self.build_per_event_execution_context
15891601

@@ -2103,6 +2115,7 @@ def execute(
21032115
middleware: Middleware | None = None,
21042116
execution_context_class: type[ExecutionContext] | None = None,
21052117
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
2118+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
21062119
**custom_context_args: Any,
21072120
) -> AwaitableOrValue[ExecutionResult]:
21082121
"""Execute a GraphQL operation.
@@ -2137,6 +2150,7 @@ def execute(
21372150
middleware,
21382151
execution_context_class,
21392152
is_awaitable,
2153+
is_async_iterable,
21402154
**custom_context_args,
21412155
)
21422156
if isinstance(result, ExecutionResult):
@@ -2167,6 +2181,7 @@ def experimental_execute_incrementally(
21672181
middleware: Middleware | None = None,
21682182
execution_context_class: type[ExecutionContext] | None = None,
21692183
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
2184+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
21702185
**custom_context_args: Any,
21712186
) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]:
21722187
"""Execute GraphQL operation incrementally (internal implementation).
@@ -2197,6 +2212,7 @@ def experimental_execute_incrementally(
21972212
enable_early_execution,
21982213
middleware,
21992214
is_awaitable,
2215+
is_async_iterable,
22002216
**custom_context_args,
22012217
)
22022218

@@ -2692,7 +2708,7 @@ def assert_event_stream(result: Any) -> AsyncIterable:
26922708
raise result
26932709

26942710
# Assert field returned an event stream, otherwise yield an error.
2695-
if not isinstance(result, AsyncIterable):
2711+
if not default_is_async_iterable(result):
26962712
msg = (
26972713
"Subscription field must return AsyncIterable."
26982714
f" Received: {inspect(result)}."

src/graphql/graphql.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
if TYPE_CHECKING:
20-
from collections.abc import Awaitable, Callable
20+
from collections.abc import AsyncIterable, Awaitable, Callable
2121
from typing import TypeGuard
2222

2323
from .pyutils import AwaitableOrValue
@@ -37,6 +37,7 @@ async def graphql(
3737
middleware: Middleware | None = None,
3838
execution_context_class: type[ExecutionContext] | None = None,
3939
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
40+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
4041
) -> ExecutionResult:
4142
"""Execute a GraphQL operation asynchronously.
4243
@@ -84,6 +85,8 @@ async def graphql(
8485
The execution context class to use to build the context
8586
:arg is_awaitable:
8687
The predicate to be used for checking whether values are awaitable
88+
:arg is_async_iterable:
89+
The predicate to be used for checking whether values are async iterables
8790
"""
8891
# Always return asynchronously for a consistent API.
8992
result = graphql_impl(
@@ -98,6 +101,7 @@ async def graphql(
98101
middleware,
99102
execution_context_class,
100103
is_awaitable,
104+
is_async_iterable,
101105
)
102106

103107
if default_is_awaitable(result):
@@ -111,6 +115,11 @@ def assume_not_awaitable(_value: Any) -> TypeGuard[Awaitable]:
111115
return False
112116

113117

118+
def assume_not_async_iterable(_value: Any) -> TypeGuard[AsyncIterable]:
119+
"""Replacement for is_async_iterable if everything is assumed to be synchronous."""
120+
return False
121+
122+
114123
def graphql_sync(
115124
schema: GraphQLSchema,
116125
source: str | Source,
@@ -138,6 +147,7 @@ def graphql_sync(
138147
if callable(check_sync)
139148
else (None if check_sync else assume_not_awaitable)
140149
)
150+
is_async_iterable = assume_not_async_iterable if not check_sync else None
141151
result = graphql_impl(
142152
schema,
143153
source,
@@ -150,6 +160,7 @@ def graphql_sync(
150160
middleware,
151161
execution_context_class,
152162
is_awaitable,
163+
is_async_iterable,
153164
)
154165

155166
# Assert that the execution was synchronous.
@@ -173,6 +184,7 @@ def graphql_impl(
173184
middleware: Middleware | None,
174185
execution_context_class: type[ExecutionContext] | None,
175186
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None,
187+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
176188
) -> AwaitableOrValue[ExecutionResult]:
177189
"""Execute a query, return asynchronously only if necessary."""
178190
# Validate Schema
@@ -206,4 +218,5 @@ def graphql_impl(
206218
middleware,
207219
execution_context_class,
208220
is_awaitable,
221+
is_async_iterable,
209222
)

src/graphql/pyutils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .group_by import group_by
2525
from .identity_func import identity_func
2626
from .inspect import inspect
27-
from .is_awaitable import is_awaitable
27+
from .is_awaitable import is_awaitable, is_async_iterable
2828
from .is_iterable import is_collection, is_iterable
2929
from .natural_compare import natural_comparison_key
3030
from .awaitable_or_value import AwaitableOrValue
@@ -59,6 +59,7 @@
5959
"group_by",
6060
"identity_func",
6161
"inspect",
62+
"is_async_iterable",
6263
"is_awaitable",
6364
"is_collection",
6465
"is_description",

src/graphql/pyutils/is_awaitable.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from typing import TYPE_CHECKING, Any, TypeGuard
88

99
if TYPE_CHECKING:
10-
from collections.abc import Awaitable
10+
from collections.abc import AsyncIterable, Awaitable
1111

12-
__all__ = ["is_awaitable"]
12+
__all__ = ["is_async_iterable", "is_awaitable"]
1313

1414
CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE
1515

@@ -31,3 +31,12 @@ def is_awaitable(value: Any) -> TypeGuard[Awaitable]:
3131
# check for other awaitables (e.g. futures)
3232
or hasattr(value, "__await__")
3333
)
34+
35+
36+
def is_async_iterable(value: Any) -> TypeGuard[AsyncIterable]:
37+
"""Return True if object is an asynchronous iterable.
38+
39+
Instead of testing whether the object is an instance of abc.AsyncIterable, we
40+
check the existence of an `__aiter__` attribute. This is much faster.
41+
"""
42+
return hasattr(value, "__aiter__")

tests/execution/test_lists.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from collections.abc import AsyncGenerator
2-
from typing import Any
1+
from collections.abc import AsyncGenerator, AsyncIterable, Callable
2+
from typing import TYPE_CHECKING, Any, TypeGuard
33

44
import pytest
55

@@ -204,6 +204,62 @@ async def __anext__(self):
204204
None,
205205
)
206206

207+
@pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
208+
async def can_customize_detection_of_async_iterables():
209+
class CustomIterable:
210+
"""An object that is both an iterable and an async iterable."""
211+
212+
stop = False
213+
214+
def __iter__(self):
215+
return self
216+
217+
def __next__(self):
218+
if self.stop:
219+
raise StopIteration
220+
self.stop = True
221+
return "hello"
222+
223+
def __aiter__(self):
224+
return self
225+
226+
async def __anext__(self):
227+
if self.stop:
228+
raise StopAsyncIteration
229+
self.stop = True
230+
return "world"
231+
232+
assert await _complete(CustomIterable()) == (
233+
{"listField": ["world"]},
234+
None,
235+
)
236+
237+
async def _complete_custom(
238+
is_async_iterable: Callable[[Any], TypeGuard[AsyncIterable]] | None = None,
239+
):
240+
return execute(
241+
build_schema("type Query { listField: [String] }"),
242+
parse("{ listField }"),
243+
Data(CustomIterable()),
244+
is_async_iterable=is_async_iterable,
245+
)
246+
247+
def use_async(result: Any) -> TypeGuard[AsyncIterable]:
248+
return isinstance(result, AsyncIterable)
249+
250+
result = await _complete_custom(use_async)
251+
252+
assert is_awaitable(result)
253+
assert await result == ({"listField": ["world"]}, None)
254+
255+
def use_sync(result: Any) -> TypeGuard[AsyncIterable]:
256+
return isinstance(result, AsyncIterable) and not hasattr(result, "__iter__")
257+
258+
result = await _complete_custom(use_sync)
259+
260+
assert not is_awaitable(result)
261+
assert result == ({"listField": ["hello"]}, None)
262+
207263
async def handles_an_async_generator_that_throws():
208264
async def list_field():
209265
yield "two"

0 commit comments

Comments
 (0)