|
1 |
| -from graphql.execution import ExecutionContext, execute |
| 1 | +from pytest import mark |
| 2 | + |
| 3 | +from graphql.execution import ExecutionContext, MapAsyncIterator, execute, subscribe |
2 | 4 | from graphql.language import parse
|
3 | 5 | from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
|
4 | 6 |
|
5 | 7 |
|
| 8 | +try: |
| 9 | + anext |
| 10 | +except NameError: # pragma: no cover (Python < 3.10) |
| 11 | + # noinspection PyShadowingBuiltins |
| 12 | + async def anext(iterator): |
| 13 | + """Return the next item from an async iterator.""" |
| 14 | + return await iterator.__anext__() |
| 15 | + |
| 16 | + |
6 | 17 | def describe_customize_execution():
|
7 | 18 | def uses_a_custom_field_resolver():
|
8 | 19 | query = parse("{ foo }")
|
@@ -39,3 +50,73 @@ def execute_field(self, parent_type, source, field_nodes, path):
|
39 | 50 | {"foo": "barbar"},
|
40 | 51 | None,
|
41 | 52 | )
|
| 53 | + |
| 54 | + |
| 55 | +def describe_customize_subscription(): |
| 56 | + @mark.asyncio |
| 57 | + async def uses_a_custom_subscribe_field_resolver(): |
| 58 | + schema = GraphQLSchema( |
| 59 | + query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}), |
| 60 | + subscription=GraphQLObjectType( |
| 61 | + "Subscription", {"foo": GraphQLField(GraphQLString)} |
| 62 | + ), |
| 63 | + ) |
| 64 | + |
| 65 | + class Root: |
| 66 | + @staticmethod |
| 67 | + async def custom_foo(): |
| 68 | + yield {"foo": "FooValue"} |
| 69 | + |
| 70 | + subscription = await subscribe( |
| 71 | + schema, |
| 72 | + document=parse("subscription { foo }"), |
| 73 | + root_value=Root(), |
| 74 | + subscribe_field_resolver=lambda root, _info: root.custom_foo(), |
| 75 | + ) |
| 76 | + assert isinstance(subscription, MapAsyncIterator) |
| 77 | + |
| 78 | + assert await anext(subscription) == ( |
| 79 | + {"foo": "FooValue"}, |
| 80 | + None, |
| 81 | + ) |
| 82 | + |
| 83 | + await subscription.aclose() |
| 84 | + |
| 85 | + @mark.asyncio |
| 86 | + async def uses_a_custom_execution_context_class(): |
| 87 | + class TestExecutionContext(ExecutionContext): |
| 88 | + def build_resolve_info(self, *args, **kwargs): |
| 89 | + resolve_info = super().build_resolve_info(*args, **kwargs) |
| 90 | + resolve_info.context["foo"] = "bar" |
| 91 | + return resolve_info |
| 92 | + |
| 93 | + async def generate_foo(_obj, info): |
| 94 | + yield info.context["foo"] |
| 95 | + |
| 96 | + def resolve_foo(message, _info): |
| 97 | + return message |
| 98 | + |
| 99 | + schema = GraphQLSchema( |
| 100 | + query=GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)}), |
| 101 | + subscription=GraphQLObjectType( |
| 102 | + "Subscription", |
| 103 | + { |
| 104 | + "foo": GraphQLField( |
| 105 | + GraphQLString, |
| 106 | + resolve=resolve_foo, |
| 107 | + subscribe=generate_foo, |
| 108 | + ) |
| 109 | + }, |
| 110 | + ), |
| 111 | + ) |
| 112 | + |
| 113 | + document = parse("subscription { foo }") |
| 114 | + subscription = await subscribe( |
| 115 | + schema, |
| 116 | + document, |
| 117 | + context_value={}, |
| 118 | + execution_context_class=TestExecutionContext, |
| 119 | + ) |
| 120 | + assert isinstance(subscription, MapAsyncIterator) |
| 121 | + |
| 122 | + assert await anext(subscription) == ({"foo": "bar"}, None) |
0 commit comments