Skip to content

Commit 45c9e38

Browse files
committed
execute/subscribe: simplify to improve debugging experience
Replicates graphql/graphql-js@cadcef8
1 parent 8845f38 commit 45c9e38

File tree

4 files changed

+111
-107
lines changed

4 files changed

+111
-107
lines changed

src/graphql/error/located_error.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TYPE_CHECKING, Collection, Optional, Union
22

3+
from ..pyutils import inspect
34
from .graphql_error import GraphQLError
45

56
if TYPE_CHECKING:
@@ -9,7 +10,7 @@
910

1011

1112
def located_error(
12-
original_error: Union[Exception, GraphQLError],
13+
original_error: Exception,
1314
nodes: Optional[Union["None", Collection["Node"]]],
1415
path: Optional[Collection[Union[str, int]]] = None,
1516
) -> GraphQLError:
@@ -19,8 +20,9 @@ def located_error(
1920
GraphQL operation, produce a new GraphQLError aware of the location in the document
2021
responsible for the original Exception.
2122
"""
23+
# Sometimes a non-error is thrown, wrap it as a TypeError to ensure consistency.
2224
if not isinstance(original_error, Exception):
23-
raise TypeError("Expected an Exception.")
25+
original_error = TypeError(f"Unexpected error value: {inspect(original_error)}")
2426
# Note: this uses a brand-check to support GraphQL errors originating from
2527
# other contexts.
2628
if isinstance(original_error, GraphQLError) and original_error.path is not None:

src/graphql/execution/execute.py

Lines changed: 55 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,6 @@ def execute_operation(
331331
# Errors from sub-fields of a NonNull type may propagate to the top level, at
332332
# which point we still log the error and null the parent field, which in this
333333
# case is the entire response.
334-
#
335-
# Similar to complete_value_catching_error.
336334
try:
337335
# noinspection PyArgumentList
338336
result = (
@@ -595,6 +593,7 @@ def resolve_field(
595593
if not field_def:
596594
return Undefined
597595

596+
return_type = field_def.type
598597
resolve_fn = field_def.resolve or self.field_resolver
599598

600599
if self.middleware_manager:
@@ -604,29 +603,6 @@ def resolve_field(
604603

605604
# Get the resolve function, regardless of if its result is normal or abrupt
606605
# (error).
607-
result = self.resolve_field_value_or_error(
608-
field_def, field_nodes, resolve_fn, source, info
609-
)
610-
611-
return self.complete_value_catching_error(
612-
field_def.type, field_nodes, info, path, result
613-
)
614-
615-
def resolve_field_value_or_error(
616-
self,
617-
field_def: GraphQLField,
618-
field_nodes: List[FieldNode],
619-
resolve_fn: GraphQLFieldResolver,
620-
source: Any,
621-
info: GraphQLResolveInfo,
622-
) -> Union[Exception, Any]:
623-
"""Resolve field to a value or an error.
624-
625-
Isolates the "ReturnOrAbrupt" behavior to not de-opt the resolve_field()
626-
method. Returns the result of resolveFn or the abrupt-return Error object.
627-
628-
For internal use only.
629-
"""
630606
try:
631607
# Build a dictionary of arguments from the field.arguments AST, using the
632608
# variables scope to fulfill any variable references.
@@ -635,58 +611,38 @@ def resolve_field_value_or_error(
635611
# Note that contrary to the JavaScript implementation, we pass the context
636612
# value as part of the resolve info.
637613
result = resolve_fn(source, info, **args)
614+
615+
completed: AwaitableOrValue[Any]
638616
if self.is_awaitable(result):
639617
# noinspection PyShadowingNames
640618
async def await_result() -> Any:
641619
try:
642-
return await result
620+
completed = self.complete_value(
621+
return_type, field_nodes, info, path, await result
622+
)
623+
if self.is_awaitable(completed):
624+
return await completed
625+
return completed
643626
except Exception as error:
644-
return error
627+
self.handle_field_error(error, field_nodes, path, return_type)
628+
return None
645629

646630
return await_result()
647-
return result
648-
except Exception as error:
649-
return error
650631

651-
def complete_value_catching_error(
652-
self,
653-
return_type: GraphQLOutputType,
654-
field_nodes: List[FieldNode],
655-
info: GraphQLResolveInfo,
656-
path: Path,
657-
result: Any,
658-
) -> AwaitableOrValue[Any]:
659-
"""Complete a value while catching an error.
660-
661-
This is a small wrapper around completeValue which detects and logs errors in
662-
the execution context.
663-
"""
664-
completed: AwaitableOrValue[Any]
665-
try:
666-
if self.is_awaitable(result):
667-
668-
async def await_result() -> Any:
669-
value = self.complete_value(
670-
return_type, field_nodes, info, path, await result
671-
)
672-
if self.is_awaitable(value):
673-
return await value
674-
return value
675-
676-
completed = await_result()
677-
else:
678-
completed = self.complete_value(
679-
return_type, field_nodes, info, path, result
680-
)
632+
completed = self.complete_value(
633+
return_type, field_nodes, info, path, result
634+
)
681635
if self.is_awaitable(completed):
682636
# noinspection PyShadowingNames
683637
async def await_completed() -> Any:
684638
try:
685639
return await completed
686640
except Exception as error:
687641
self.handle_field_error(error, field_nodes, path, return_type)
642+
return None
688643

689644
return await_completed()
645+
690646
return completed
691647
except Exception as error:
692648
self.handle_field_error(error, field_nodes, path, return_type)
@@ -825,10 +781,45 @@ def complete_list_value(
825781
for index, item in enumerate(result):
826782
# No need to modify the info object containing the path, since from here on
827783
# it is not ever accessed by resolver functions.
828-
field_path = path.add_key(index, None)
829-
completed_item = self.complete_value_catching_error(
830-
item_type, field_nodes, info, field_path, item
831-
)
784+
item_path = path.add_key(index, None)
785+
completed_item: AwaitableOrValue[Any]
786+
if is_awaitable(item):
787+
# noinspection PyShadowingNames
788+
async def await_completed(item: Any, item_path: Path) -> Any:
789+
try:
790+
completed = self.complete_value(
791+
item_type, field_nodes, info, item_path, await item
792+
)
793+
if is_awaitable(completed):
794+
return await completed
795+
return completed
796+
except Exception as error:
797+
self.handle_field_error(
798+
error, field_nodes, item_path, item_type
799+
)
800+
return None
801+
802+
completed_item = await_completed(item, item_path)
803+
else:
804+
try:
805+
completed_item = self.complete_value(
806+
item_type, field_nodes, info, item_path, item
807+
)
808+
if is_awaitable(completed_item):
809+
# noinspection PyShadowingNames
810+
async def await_completed(item: Any, item_path: Path) -> Any:
811+
try:
812+
return await item
813+
except Exception as error:
814+
self.handle_field_error(
815+
error, field_nodes, item_path, item_type
816+
)
817+
return None
818+
819+
completed_item = await_completed(completed_item, item_path)
820+
except Exception as error:
821+
self.handle_field_error(error, field_nodes, item_path, item_type)
822+
completed_item = None
832823

833824
if is_awaitable(completed_item):
834825
append_awaitable(index)

src/graphql/subscription/subscribe.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
Any,
44
AsyncIterable,
55
AsyncIterator,
6-
Awaitable,
76
Dict,
87
Optional,
98
Union,
10-
cast,
119
)
1210

1311
from ..error import GraphQLError, located_error
@@ -18,6 +16,7 @@
1816
ExecutionContext,
1917
ExecutionResult,
2018
)
19+
from ..execution.values import get_argument_values
2120
from ..language import DocumentNode
2221
from ..pyutils import Path, inspect
2322
from ..type import GraphQLFieldResolver, GraphQLSchema
@@ -56,18 +55,15 @@ async def subscribe(
5655
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
5756
a stream of ExecutionResults representing the response stream.
5857
"""
59-
try:
60-
result_or_stream = await create_source_event_stream(
61-
schema,
62-
document,
63-
root_value,
64-
context_value,
65-
variable_values,
66-
operation_name,
67-
subscribe_field_resolver,
68-
)
69-
except GraphQLError as error:
70-
return ExecutionResult(data=None, errors=[error])
58+
result_or_stream = await create_source_event_stream(
59+
schema,
60+
document,
61+
root_value,
62+
context_value,
63+
variable_values,
64+
operation_name,
65+
subscribe_field_resolver,
66+
)
7167
if isinstance(result_or_stream, ExecutionResult):
7268
return result_or_stream
7369

@@ -111,11 +107,15 @@ async def create_source_event_stream(
111107
112108
Returns a coroutine that yields an AsyncIterable.
113109
114-
If the client provided invalid arguments, the source stream could not be created,
115-
or the resolver did not return an AsyncIterable, this function will throw an error,
116-
which should be caught and handled by the caller.
110+
If the client-provided arguments to this function do not result in a compliant
111+
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
112+
data will be returned.
117113
118-
A Source Event Stream represents a sequence of events, each of which triggers a
114+
If the source stream could not be created due to faulty subscription resolver logic
115+
or underlying systems, the coroutine object will yield a single ExecutionResult
116+
containing ``errors`` and no ``data``.
117+
118+
A source event stream represents a sequence of events, each of which triggers a
119119
GraphQL execution for that event.
120120
121121
This may be useful when hosting the stateful subscription service in a different
@@ -143,6 +143,14 @@ async def create_source_event_stream(
143143
if isinstance(context, list):
144144
return ExecutionResult(data=None, errors=context)
145145

146+
try:
147+
return await execute_subscription(context)
148+
except GraphQLError as error:
149+
return ExecutionResult(data=None, errors=[error])
150+
151+
152+
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
153+
schema = context.schema
146154
type_ = get_operation_root_type(schema, context.operation)
147155
fields = context.collect_fields(type_, context.operation.selection_set, {}, set())
148156
response_names = list(fields)
@@ -157,29 +165,35 @@ async def create_source_event_stream(
157165
f"The subscription field '{field_name}' is not defined.", field_nodes
158166
)
159167

168+
path = Path(None, response_name, type_.name)
169+
info = context.build_resolve_info(field_def, field_nodes, type_, path)
170+
171+
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
172+
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
173+
174+
# Build a dictionary of arguments from the field.arguments AST, using the
175+
# variables scope to fulfill any variable references.
176+
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
177+
160178
# Call the `subscribe()` resolver or the default resolver to produce an
161179
# AsyncIterable yielding raw payloads.
162180
resolve_fn = field_def.subscribe or context.field_resolver
163181

164-
path = Path(None, response_name, type_.name)
165-
166-
info = context.build_resolve_info(field_def, field_nodes, type_, path)
182+
try:
183+
event_stream = resolve_fn(context.root_value, info, **args)
184+
if context.is_awaitable(event_stream):
185+
event_stream = await event_stream
186+
except Exception as error:
187+
event_stream = error
167188

168-
# `resolve_field_value_or_error` implements the "ResolveFieldEventStream" algorithm
169-
# from GraphQL specification. It differs from `resolve_field_value` due to
170-
# providing a different `resolve_fn`.
171-
result = context.resolve_field_value_or_error(
172-
field_def, field_nodes, resolve_fn, root_value, info
173-
)
174-
event_stream = await cast(Awaitable, result) if isawaitable(result) else result
175-
# If `event_stream` is an Error, rethrow a located error.
176189
if isinstance(event_stream, Exception):
177190
raise located_error(event_stream, field_nodes, path.as_list())
178191

179192
# Assert field returned an event stream, otherwise yield an error.
180-
if isinstance(event_stream, AsyncIterable):
181-
return event_stream
182-
raise TypeError(
183-
"Subscription field must return AsyncIterable."
184-
f" Received: {inspect(event_stream)}."
185-
)
193+
if not isinstance(event_stream, AsyncIterable):
194+
raise TypeError(
195+
"Subscription field must return AsyncIterable."
196+
f" Received: {inspect(event_stream)}."
197+
)
198+
199+
return event_stream

tests/error/test_located_error.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from typing import cast, Any
22

3-
from pytest import raises # type: ignore
4-
53
from graphql.error import GraphQLError, located_error
64

75

86
def describe_located_error():
97
def throws_without_an_original_error():
10-
with raises(TypeError) as exc_info:
11-
# noinspection PyTypeChecker
12-
located_error([], [], []) # type: ignore
13-
assert str(exc_info.value) == "Expected an Exception."
8+
e = located_error([], [], []).original_error # type: ignore
9+
assert isinstance(e, TypeError)
10+
assert str(e) == "Unexpected error value: []"
1411

1512
def passes_graphql_error_through():
1613
path = ["path", 3, "to", "field"]

0 commit comments

Comments
 (0)