Skip to content

Commit 586fcc3

Browse files
committed
subscribe: rewrite async functions
Replicates graphql/graphql-js@3958158
1 parent 47a1651 commit 586fcc3

File tree

1 file changed

+39
-36
lines changed

1 file changed

+39
-36
lines changed

src/graphql/subscription/subscribe.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
8888
)
8989
return await result if isawaitable(result) else result # type: ignore
9090

91+
# Map every source value to a ExecutionResult value as described above.
9192
return MapAsyncIterator(result_or_stream, map_source_to_response)
9293

9394

@@ -127,25 +128,35 @@ async def create_source_event_stream(
127128
# mistake which should throw an early error.
128129
assert_valid_execution_arguments(schema, document, variable_values)
129130

130-
# If a valid context cannot be created due to incorrect arguments, this will throw
131-
# an error.
132-
context = ExecutionContext.build(
133-
schema,
134-
document,
135-
root_value,
136-
context_value,
137-
variable_values,
138-
operation_name,
139-
field_resolver,
140-
)
131+
try:
132+
# If a valid context cannot be created due to incorrect arguments,
133+
# this will throw an error.
134+
context = ExecutionContext.build(
135+
schema,
136+
document,
137+
root_value,
138+
context_value,
139+
variable_values,
140+
operation_name,
141+
field_resolver,
142+
)
141143

142-
# Return early errors if execution context failed.
143-
if isinstance(context, list):
144-
return ExecutionResult(data=None, errors=context)
144+
# Return early errors if execution context failed.
145+
if isinstance(context, list):
146+
return ExecutionResult(data=None, errors=context)
147+
148+
event_stream = await execute_subscription(context)
149+
150+
# Assert field returned an event stream, otherwise yield an error.
151+
if not isinstance(event_stream, AsyncIterable):
152+
raise TypeError(
153+
"Subscription field must return AsyncIterable."
154+
f" Received: {inspect(event_stream)}."
155+
)
156+
return event_stream
145157

146-
try:
147-
return await execute_subscription(context)
148158
except GraphQLError as error:
159+
# Report it as an ExecutionResult, containing only errors and no data.
149160
return ExecutionResult(data=None, errors=[error])
150161

151162

@@ -171,29 +182,21 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
171182
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
172183
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
173184

174-
# Build a dictionary of arguments from the field.arguments AST, using the
175-
# variables scope to fulfill any variable references.
176-
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
185+
try:
186+
# Build a dictionary of arguments from the field.arguments AST, using the
187+
# variables scope to fulfill any variable references.
188+
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
177189

178-
# Call the `subscribe()` resolver or the default resolver to produce an
179-
# AsyncIterable yielding raw payloads.
180-
resolve_fn = field_def.subscribe or context.field_resolver
190+
# Call the `subscribe()` resolver or the default resolver to produce an
191+
# AsyncIterable yielding raw payloads.
192+
resolve_fn = field_def.subscribe or context.field_resolver
181193

182-
try:
183194
event_stream = resolve_fn(context.root_value, info, **args)
184195
if context.is_awaitable(event_stream):
185196
event_stream = await event_stream
186-
except Exception as error:
187-
event_stream = error
188-
189-
if isinstance(event_stream, Exception):
190-
raise located_error(event_stream, field_nodes, path.as_list())
197+
if isinstance(event_stream, Exception):
198+
raise event_stream
191199

192-
# Assert field returned an event stream, otherwise yield an error.
193-
if not isinstance(event_stream, AsyncIterable):
194-
raise TypeError(
195-
"Subscription field must return AsyncIterable."
196-
f" Received: {inspect(event_stream)}."
197-
)
198-
199-
return event_stream
200+
return event_stream
201+
except Exception as error:
202+
raise located_error(error, field_nodes, path.as_list())

0 commit comments

Comments
 (0)