Skip to content

Commit 1a5aca8

Browse files
committed
Move customization tests for to a separate module
1 parent 1990bdc commit 1a5aca8

File tree

2 files changed

+83
-75
lines changed

2 files changed

+83
-75
lines changed

tests/execution/test_customize.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1-
from graphql.execution import ExecutionContext, execute
1+
from pytest import mark
2+
3+
from graphql.execution import ExecutionContext, MapAsyncIterator, execute, subscribe
24
from graphql.language import parse
35
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
46

57

8+
try:
9+
anext
10+
except NameError: # pragma: no cover (Python < 3.10)
11+
# noinspection PyShadowingBuiltins
12+
async def anext(iterator):
13+
"""Return the next item from an async iterator."""
14+
return await iterator.__anext__()
15+
16+
617
def describe_customize_execution():
718
def uses_a_custom_field_resolver():
819
query = parse("{ foo }")
@@ -39,3 +50,73 @@ def execute_field(self, parent_type, source, field_nodes, path):
3950
{"foo": "barbar"},
4051
None,
4152
)
53+
54+
55+
def describe_customize_subscription():
56+
@mark.asyncio
57+
async def uses_a_custom_subscribe_field_resolver():
58+
schema = GraphQLSchema(
59+
query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}),
60+
subscription=GraphQLObjectType(
61+
"Subscription", {"foo": GraphQLField(GraphQLString)}
62+
),
63+
)
64+
65+
class Root:
66+
@staticmethod
67+
async def custom_foo():
68+
yield {"foo": "FooValue"}
69+
70+
subscription = await subscribe(
71+
schema,
72+
document=parse("subscription { foo }"),
73+
root_value=Root(),
74+
subscribe_field_resolver=lambda root, _info: root.custom_foo(),
75+
)
76+
assert isinstance(subscription, MapAsyncIterator)
77+
78+
assert await anext(subscription) == (
79+
{"foo": "FooValue"},
80+
None,
81+
)
82+
83+
await subscription.aclose()
84+
85+
@mark.asyncio
86+
async def uses_a_custom_execution_context_class():
87+
class TestExecutionContext(ExecutionContext):
88+
def build_resolve_info(self, *args, **kwargs):
89+
resolve_info = super().build_resolve_info(*args, **kwargs)
90+
resolve_info.context["foo"] = "bar"
91+
return resolve_info
92+
93+
async def generate_foo(_obj, info):
94+
yield info.context["foo"]
95+
96+
def resolve_foo(message, _info):
97+
return message
98+
99+
schema = GraphQLSchema(
100+
query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}),
101+
subscription=GraphQLObjectType(
102+
"Subscription",
103+
{
104+
"foo": GraphQLField(
105+
GraphQLString,
106+
resolve=resolve_foo,
107+
subscribe=generate_foo,
108+
)
109+
},
110+
),
111+
)
112+
113+
document = parse("subscription { foo }")
114+
subscription = await subscribe(
115+
schema,
116+
document,
117+
context_value={},
118+
execution_context_class=TestExecutionContext,
119+
)
120+
assert isinstance(subscription, MapAsyncIterator)
121+
122+
assert await anext(subscription) == ({"foo": "bar"}, None)

tests/execution/test_subscribe.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,7 @@
33

44
from pytest import mark, raises
55

6-
from graphql.execution import (
7-
ExecutionContext,
8-
MapAsyncIterator,
9-
create_source_event_stream,
10-
subscribe,
11-
)
6+
from graphql.execution import MapAsyncIterator, create_source_event_stream, subscribe
127
from graphql.language import parse
138
from graphql.pyutils import SimplePubSub
149
from graphql.type import (
@@ -214,35 +209,6 @@ async def foo_generator(_obj, _info):
214209

215210
await subscription.aclose()
216211

217-
@mark.asyncio
218-
async def uses_a_custom_default_subscribe_field_resolver():
219-
schema = GraphQLSchema(
220-
query=DummyQueryType,
221-
subscription=GraphQLObjectType(
222-
"Subscription", {"foo": GraphQLField(GraphQLString)}
223-
),
224-
)
225-
226-
class Root:
227-
@staticmethod
228-
async def custom_foo():
229-
yield {"foo": "FooValue"}
230-
231-
subscription = await subscribe(
232-
schema,
233-
document=parse("subscription { foo }"),
234-
root_value=Root(),
235-
subscribe_field_resolver=lambda root, _info: root.custom_foo(),
236-
)
237-
assert isinstance(subscription, MapAsyncIterator)
238-
239-
assert await anext(subscription) == (
240-
{"foo": "FooValue"},
241-
None,
242-
)
243-
244-
await subscription.aclose()
245-
246212
@mark.asyncio
247213
async def should_only_resolve_the_first_field_of_invalid_multi_field():
248214
did_resolve = {"foo": False, "bar": False}
@@ -897,42 +863,3 @@ def resolve_message(message, _info):
897863
assert isinstance(subscription, MapAsyncIterator)
898864

899865
assert await anext(subscription) == ({"newMessage": "Hello"}, None)
900-
901-
@mark.asyncio
902-
async def should_work_with_custom_execution_contexts():
903-
class CustomExecutionContext(ExecutionContext):
904-
def build_resolve_info(self, *args, **kwargs):
905-
resolve_info = super().build_resolve_info(*args, **kwargs)
906-
resolve_info.context["foo"] = "bar"
907-
return resolve_info
908-
909-
async def generate_messages(_obj, info):
910-
yield info.context["foo"]
911-
912-
def resolve_message(message, _info):
913-
return message
914-
915-
schema = GraphQLSchema(
916-
query=QueryType,
917-
subscription=GraphQLObjectType(
918-
"Subscription",
919-
{
920-
"newMessage": GraphQLField(
921-
GraphQLString,
922-
resolve=resolve_message,
923-
subscribe=generate_messages,
924-
)
925-
},
926-
),
927-
)
928-
929-
document = parse("subscription { newMessage }")
930-
subscription = await subscribe(
931-
schema,
932-
document,
933-
context_value={},
934-
execution_context_class=CustomExecutionContext,
935-
)
936-
assert isinstance(subscription, MapAsyncIterator)
937-
938-
assert await anext(subscription) == ({"newMessage": "bar"}, None)

0 commit comments

Comments
 (0)