Skip to content

Commit 8ebf563

Browse files
authored
Support custom execution contexts in subscriptions (#181)
1 parent c4e872d commit 8ebf563

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

src/graphql/execution/subscribe.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from inspect import isawaitable
2-
from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional, Union
2+
from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional, Type, Union
33

44
from ..error import GraphQLError, located_error
55
from ..execution.collect_fields import collect_fields
@@ -29,6 +29,7 @@ async def subscribe(
2929
operation_name: Optional[str] = None,
3030
field_resolver: Optional[GraphQLFieldResolver] = None,
3131
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
32+
execution_context_class: Optional[Type["ExecutionContext"]] = None,
3233
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
3334
"""Create a GraphQL subscription.
3435
@@ -57,6 +58,7 @@ async def subscribe(
5758
variable_values,
5859
operation_name,
5960
subscribe_field_resolver,
61+
execution_context_class,
6062
)
6163
if isinstance(result_or_stream, ExecutionResult):
6264
return result_or_stream
@@ -79,6 +81,7 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
7981
variable_values,
8082
operation_name,
8183
field_resolver,
84+
execution_context_class=execution_context_class,
8285
)
8386
return await result if isawaitable(result) else result
8487

@@ -94,6 +97,7 @@ async def create_source_event_stream(
9497
variable_values: Optional[Dict[str, Any]] = None,
9598
operation_name: Optional[str] = None,
9699
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
100+
execution_context_class: Optional[Type["ExecutionContext"]] = None,
97101
) -> Union[AsyncIterable[Any], ExecutionResult]:
98102
"""Create source event stream
99103
@@ -122,9 +126,12 @@ async def create_source_event_stream(
122126
# mistake which should throw an early error.
123127
assert_valid_execution_arguments(schema, document, variable_values)
124128

129+
if not execution_context_class:
130+
execution_context_class = ExecutionContext
131+
125132
# If a valid context cannot be created due to incorrect arguments,
126133
# a "Response" with only errors is returned.
127-
context = ExecutionContext.build(
134+
context = execution_context_class.build(
128135
schema,
129136
document,
130137
root_value,

tests/execution/test_subscribe.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33

44
from pytest import mark, raises
55

6-
from graphql.execution import MapAsyncIterator, create_source_event_stream, subscribe
6+
from graphql.execution import (
7+
create_source_event_stream,
8+
subscribe,
9+
MapAsyncIterator,
10+
ExecutionContext,
11+
)
712
from graphql.language import parse
813
from graphql.pyutils import SimplePubSub
914
from graphql.type import (
@@ -892,3 +897,42 @@ def resolve_message(message, _info):
892897
assert isinstance(subscription, MapAsyncIterator)
893898

894899
assert await anext(subscription) == ({"newMessage": "Hello"}, None)
900+
901+
@mark.asyncio
902+
async def should_work_with_custom_execution_contexts():
903+
class CustomExecutionContext(ExecutionContext):
904+
def build_resolve_info(self, *args, **kwargs):
905+
resolve_info = super().build_resolve_info(*args, **kwargs)
906+
resolve_info.context['foo'] = 'bar'
907+
return resolve_info
908+
909+
async def generate_messages(_obj, info):
910+
yield info.context['foo']
911+
912+
def resolve_message(message, _info):
913+
return message
914+
915+
schema = GraphQLSchema(
916+
query=QueryType,
917+
subscription=GraphQLObjectType(
918+
"Subscription",
919+
{
920+
"newMessage": GraphQLField(
921+
GraphQLString,
922+
resolve=resolve_message,
923+
subscribe=generate_messages,
924+
)
925+
},
926+
),
927+
)
928+
929+
document = parse("subscription { newMessage }")
930+
subscription = await subscribe(
931+
schema,
932+
document,
933+
context_value={},
934+
execution_context_class=CustomExecutionContext
935+
)
936+
assert isinstance(subscription, MapAsyncIterator)
937+
938+
assert await anext(subscription) == ({"newMessage": "bar"}, None)

0 commit comments

Comments
 (0)