1
1
from inspect import isawaitable
2
- from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Union
2
+ from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Type , Union
3
3
4
4
from ..error import GraphQLError , located_error
5
5
from ..execution .collect_fields import collect_fields
@@ -29,6 +29,7 @@ async def subscribe(
29
29
operation_name : Optional [str ] = None ,
30
30
field_resolver : Optional [GraphQLFieldResolver ] = None ,
31
31
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
32
+ execution_context_class : Optional [Type ["ExecutionContext" ]] = None ,
32
33
) -> Union [AsyncIterator [ExecutionResult ], ExecutionResult ]:
33
34
"""Create a GraphQL subscription.
34
35
@@ -57,6 +58,7 @@ async def subscribe(
57
58
variable_values ,
58
59
operation_name ,
59
60
subscribe_field_resolver ,
61
+ execution_context_class ,
60
62
)
61
63
if isinstance (result_or_stream , ExecutionResult ):
62
64
return result_or_stream
@@ -79,6 +81,7 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
79
81
variable_values ,
80
82
operation_name ,
81
83
field_resolver ,
84
+ execution_context_class = execution_context_class ,
82
85
)
83
86
return await result if isawaitable (result ) else result
84
87
@@ -94,6 +97,7 @@ async def create_source_event_stream(
94
97
variable_values : Optional [Dict [str , Any ]] = None ,
95
98
operation_name : Optional [str ] = None ,
96
99
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
100
+ execution_context_class : Optional [Type ["ExecutionContext" ]] = None ,
97
101
) -> Union [AsyncIterable [Any ], ExecutionResult ]:
98
102
"""Create source event stream
99
103
@@ -122,9 +126,12 @@ async def create_source_event_stream(
122
126
# mistake which should throw an early error.
123
127
assert_valid_execution_arguments (schema , document , variable_values )
124
128
129
+ if not execution_context_class :
130
+ execution_context_class = ExecutionContext
131
+
125
132
# If a valid context cannot be created due to incorrect arguments,
126
133
# a "Response" with only errors is returned.
127
- context = ExecutionContext .build (
134
+ context = execution_context_class .build (
128
135
schema ,
129
136
document ,
130
137
root_value ,
0 commit comments