Skip to content

Commit 613fe7b

Browse files
committed
initial test pass
1 parent b759723 commit 613fe7b

File tree

2 files changed

+88
-47
lines changed

2 files changed

+88
-47
lines changed

src/graphql/execution/execute.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,7 @@ def subscribe(
20432043
type_resolver: GraphQLTypeResolver | None = None,
20442044
subscribe_field_resolver: GraphQLFieldResolver | None = None,
20452045
execution_context_class: type[ExecutionContext] | None = None,
2046+
middleware: MiddlewareManager | None = None,
20462047
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
20472048
"""Create a GraphQL subscription.
20482049
@@ -2082,6 +2083,7 @@ def subscribe(
20822083
field_resolver,
20832084
type_resolver,
20842085
subscribe_field_resolver,
2086+
middleware=middleware,
20852087
)
20862088

20872089
# Return early errors if execution context failed.

tests/execution/test_middleware.py

Lines changed: 86 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
1+
import inspect
12
from typing import Awaitable, cast
23

34
import pytest
5+
from graphql import subscribe
46
from graphql.execution import Middleware, MiddlewareManager, execute
57
from graphql.language.parser import parse
68
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
79

8-
def _create_schema(tp: GraphQLObjectType, is_subscription: bool) -> GraphQLSchema:
10+
11+
def _create_schema(
12+
tp: GraphQLObjectType, is_subscription: bool = False
13+
) -> GraphQLSchema:
914
if is_subscription:
10-
noop_type = GraphQLObjectType("Noop", {"noop": GraphQLField(GraphQLString, resolve=lambda *_: "noop")})
15+
noop_type = GraphQLObjectType(
16+
"Noop", {"noop": GraphQLField(GraphQLString, resolve=lambda *_: "noop")}
17+
)
1118
return GraphQLSchema(query=noop_type, subscription=tp)
1219
return GraphQLSchema(tp)
13-
@pytest.mark.parametrize("is_subscription", [False, True], ids=["query", "subscription"])
14-
def test_describe_middleware(is_subscription: bool):
1520

16-
def test_test_describe_with_manager():
17-
def test_default():
21+
22+
def describe_middleware():
23+
def describe_with_manager():
24+
def default():
1825
doc = parse("{ field }")
1926

2027
# noinspection PyMethodMayBeStatic
2128
class Data:
22-
def test_field(self, _info):
29+
def field(self, _info):
2330
return "resolved"
2431

2532
test_type = GraphQLObjectType(
@@ -28,20 +35,20 @@ def test_field(self, _info):
2835

2936
middlewares = MiddlewareManager()
3037
result = execute(
31-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
38+
_create_schema(test_type), doc, Data(), middleware=middlewares
3239
)
3340

3441
assert result.data["field"] == "resolved" # type: ignore
3542

36-
def test_single_function():
43+
def single_function():
3744
doc = parse("{ first second }")
3845

3946
# noinspection PyMethodMayBeStatic
4047
class Data:
41-
def test_first(self, _info):
48+
def first(self, _info):
4249
return "one"
4350

44-
def test_second(self, _info):
51+
def second(self, _info):
4552
return "two"
4653

4754
test_type = GraphQLObjectType(
@@ -57,12 +64,12 @@ def reverse_middleware(next_, *args, **kwargs):
5764

5865
middlewares = MiddlewareManager(reverse_middleware)
5966
result = execute(
60-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
67+
_create_schema(test_type), doc, Data(), middleware=middlewares
6168
)
6269

6370
assert result.data == {"first": "eno", "second": "owt"} # type: ignore
6471

65-
def test_two_functions_and_field_resolvers():
72+
def two_functions_and_field_resolvers():
6673
doc = parse("{ first second }")
6774

6875
# noinspection PyMethodMayBeStatic
@@ -90,21 +97,21 @@ def capitalize_middleware(next_, *args, **kwargs):
9097

9198
middlewares = MiddlewareManager(reverse_middleware, capitalize_middleware)
9299
result = execute(
93-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
100+
_create_schema(test_type), doc, Data(), middleware=middlewares
94101
)
95102

96103
assert result.data == {"first": "Eno", "second": "Owt"} # type: ignore
97104

98105
@pytest.mark.asyncio()
99-
async def test_single_async_function():
106+
async def single_async_function():
100107
doc = parse("{ first second }")
101108

102109
# noinspection PyMethodMayBeStatic
103110
class Data:
104-
async def test_first(self, _info):
111+
async def first(self, _info):
105112
return "one"
106113

107-
async def test_second(self, _info):
114+
async def second(self, _info):
108115
return "two"
109116

110117
test_type = GraphQLObjectType(
@@ -120,21 +127,21 @@ async def reverse_middleware(next_, *args, **kwargs):
120127

121128
middlewares = MiddlewareManager(reverse_middleware)
122129
awaitable_result = execute(
123-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
130+
_create_schema(test_type), doc, Data(), middleware=middlewares
124131
)
125132
assert isinstance(awaitable_result, Awaitable)
126133
result = await awaitable_result
127134
assert result.data == {"first": "eno", "second": "owt"}
128135

129-
def test_single_object():
136+
def single_object():
130137
doc = parse("{ first second }")
131138

132139
# noinspection PyMethodMayBeStatic
133140
class Data:
134-
def test_first(self, _info):
141+
def first(self, _info):
135142
return "one"
136143

137-
def test_second(self, _info):
144+
def second(self, _info):
138145
return "two"
139146

140147
test_type = GraphQLObjectType(
@@ -147,17 +154,17 @@ def test_second(self, _info):
147154

148155
class ReverseMiddleware:
149156
# noinspection PyMethodMayBeStatic
150-
def test_resolve(self, next_, *args, **kwargs):
157+
def resolve(self, next_, *args, **kwargs):
151158
return next_(*args, **kwargs)[::-1]
152159

153160
middlewares = MiddlewareManager(ReverseMiddleware())
154161
result = execute(
155-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
162+
_create_schema(test_type), doc, Data(), middleware=middlewares
156163
)
157164

158165
assert result.data == {"first": "eno", "second": "owt"} # type: ignore
159166

160-
def test_skip_middleware_without_resolve_method():
167+
def skip_middleware_without_resolve_method():
161168
class BadMiddleware:
162169
pass # no resolve method here
163170

@@ -173,12 +180,12 @@ class BadMiddleware:
173180
middleware=MiddlewareManager(BadMiddleware()),
174181
) == ({"foo": "bar"}, None)
175182

176-
def test_with_function_and_object():
183+
def with_function_and_object():
177184
doc = parse("{ field }")
178185

179186
# noinspection PyMethodMayBeStatic
180187
class Data:
181-
def test_field(self, _info):
188+
def field(self, _info):
182189
return "resolved"
183190

184191
test_type = GraphQLObjectType(
@@ -190,28 +197,28 @@ def reverse_middleware(next_, *args, **kwargs):
190197

191198
class CaptitalizeMiddleware:
192199
# noinspection PyMethodMayBeStatic
193-
def test_resolve(self, next_, *args, **kwargs):
200+
def resolve(self, next_, *args, **kwargs):
194201
return next_(*args, **kwargs).capitalize()
195202

196203
middlewares = MiddlewareManager(reverse_middleware, CaptitalizeMiddleware())
197204
result = execute(
198-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
205+
_create_schema(test_type), doc, Data(), middleware=middlewares
199206
)
200207
assert result.data == {"field": "Devloser"} # type: ignore
201208

202209
middlewares = MiddlewareManager(CaptitalizeMiddleware(), reverse_middleware)
203210
result = execute(
204-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
211+
_create_schema(test_type), doc, Data(), middleware=middlewares
205212
)
206213
assert result.data == {"field": "devloseR"} # type: ignore
207214

208215
@pytest.mark.asyncio()
209-
async def test_with_async_function_and_object():
216+
async def with_async_function_and_object():
210217
doc = parse("{ field }")
211218

212219
# noinspection PyMethodMayBeStatic
213220
class Data:
214-
async def test_field(self, _info):
221+
async def field(self, _info):
215222
return "resolved"
216223

217224
test_type = GraphQLObjectType(
@@ -228,54 +235,86 @@ async def test_resolve(self, next_, *args, **kwargs):
228235

229236
middlewares = MiddlewareManager(reverse_middleware, CaptitalizeMiddleware())
230237
awaitable_result = execute(
231-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
238+
_create_schema(test_type), doc, Data(), middleware=middlewares
232239
)
233240
assert isinstance(awaitable_result, Awaitable)
234241
result = await awaitable_result
235242
assert result.data == {"field": "Devloser"}
236243

237244
middlewares = MiddlewareManager(CaptitalizeMiddleware(), reverse_middleware)
238245
awaitable_result = execute(
239-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
246+
_create_schema(test_type), doc, Data(), middleware=middlewares
240247
)
241248
assert isinstance(awaitable_result, Awaitable)
242249
result = await awaitable_result
243250
assert result.data == {"field": "devloseR"}
244251

245-
def test_describe_without_manager():
246-
def test_no_middleware():
252+
@pytest.mark.asyncio()
253+
async def subscription_simple():
254+
async def bar_resolve(_obj, _info):
255+
yield "bar"
256+
257+
test_type = GraphQLObjectType(
258+
"Subscription",
259+
{
260+
"bar": GraphQLField(
261+
GraphQLString,
262+
resolve=lambda message, _: message,
263+
subscribe=bar_resolve,
264+
),
265+
},
266+
)
267+
doc = parse("subscription { bar }")
268+
269+
async def reverse_middleware(next_, value, info, **kwargs):
270+
awaitable_maybe = next_(value, info, **kwargs)
271+
if inspect.isawaitable(awaitable_maybe):
272+
return (await awaitable_maybe)[::-1]
273+
return awaitable_maybe[::-1]
274+
275+
agen = subscribe(
276+
_create_schema(test_type, is_subscription=True),
277+
doc,
278+
middleware=MiddlewareManager(reverse_middleware),
279+
)
280+
assert inspect.isasyncgen(agen)
281+
data = (await agen.__anext__()).data
282+
assert data == {"bar": "rab"}
283+
284+
def describe_without_manager():
285+
def no_middleware():
247286
doc = parse("{ field }")
248287

249288
# noinspection PyMethodMayBeStatic
250289
class Data:
251-
def test_field(self, _info):
290+
def field(self, _info):
252291
return "resolved"
253292

254293
test_type = GraphQLObjectType(
255294
"TestType", {"field": GraphQLField(GraphQLString)}
256295
)
257296

258-
result = execute(_create_schema(test_type, is_subscription), doc, Data(), middleware=None)
297+
result = execute(_create_schema(test_type), doc, Data(), middleware=None)
259298

260299
assert result.data["field"] == "resolved" # type: ignore
261300

262-
def test_empty_middleware_list():
301+
def empty_middleware_list():
263302
doc = parse("{ field }")
264303

265304
# noinspection PyMethodMayBeStatic
266305
class Data:
267-
def test_field(self, _info):
306+
def field(self, _info):
268307
return "resolved"
269308

270309
test_type = GraphQLObjectType(
271310
"TestType", {"field": GraphQLField(GraphQLString)}
272311
)
273312

274-
result = execute(_create_schema(test_type, is_subscription), doc, Data(), middleware=[])
313+
result = execute(_create_schema(test_type), doc, Data(), middleware=[])
275314

276315
assert result.data["field"] == "resolved" # type: ignore
277316

278-
def test_bad_middleware_object():
317+
def bad_middleware_object():
279318
doc = parse("{ field }")
280319

281320
test_type = GraphQLObjectType(
@@ -285,7 +324,7 @@ def test_bad_middleware_object():
285324
with pytest.raises(TypeError) as exc_info:
286325
# noinspection PyTypeChecker
287326
execute(
288-
_create_schema(test_type, is_subscription),
327+
_create_schema(test_type),
289328
doc,
290329
None,
291330
middleware=cast(Middleware, {"bad": "value"}),
@@ -297,12 +336,12 @@ def test_bad_middleware_object():
297336
" Got {'bad': 'value'} instead."
298337
)
299338

300-
def test_list_of_functions():
339+
def list_of_functions():
301340
doc = parse("{ field }")
302341

303342
# noinspection PyMethodMayBeStatic
304343
class Data:
305-
def test_field(self, _info):
344+
def field(self, _info):
306345
return "resolved"
307346

308347
test_type = GraphQLObjectType(
@@ -312,11 +351,11 @@ def test_field(self, _info):
312351
log = []
313352

314353
class LogMiddleware:
315-
def test___init__(self, name):
354+
def __init__(self, name):
316355
self.name = name
317356

318357
# noinspection PyMethodMayBeStatic
319-
def test_resolve(self, next_, *args, **kwargs):
358+
def resolve(self, next_, *args, **kwargs):
320359
log.append(f"enter {self.name}")
321360
value = next_(*args, **kwargs)
322361
log.append(f"exit {self.name}")
@@ -325,7 +364,7 @@ def test_resolve(self, next_, *args, **kwargs):
325364
middlewares = [LogMiddleware("A"), LogMiddleware("B"), LogMiddleware("C")]
326365

327366
result = execute(
328-
_create_schema(test_type, is_subscription), doc, Data(), middleware=middlewares
367+
_create_schema(test_type), doc, Data(), middleware=middlewares
329368
)
330369
assert result.data == {"field": "resolved"} # type: ignore
331370

0 commit comments

Comments
 (0)