Skip to content

Commit c7788cb

Browse files
committed
Make subscriptions use as well the document node
1 parent 7795355 commit c7788cb

File tree

4 files changed

+36
-19
lines changed

4 files changed

+36
-19
lines changed

src/graphql_server/http/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4+
from graphql.language import DocumentNode
45
from typing import TYPE_CHECKING, Any, Optional
56
from typing_extensions import Literal, TypedDict
67

@@ -30,10 +31,11 @@ class GraphQLRequestData:
3031
# query is optional here as it can be added by an extensions
3132
# (for example an extension for persisted queries)
3233
query: Optional[str]
34+
document: Optional[DocumentNode]
3335
variables: Optional[dict[str, Any]]
3436
operation_name: Optional[str]
3537
extensions: Optional[dict[str, Any]]
36-
protocol: Literal["http", "multipart-subscription"] = "http"
38+
protocol: Literal["http", "multipart-subscription", "subscription"] = "http"
3739

3840

3941
__all__ = [

src/graphql_server/http/async_base_view.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ async def execute_operation(
218218
if request_data.protocol == "multipart-subscription":
219219
return await subscribe(
220220
schema=self.schema,
221-
query=request_data.query, # type: ignore
221+
query=request_data.document or request_data.query, # type: ignore
222222
variable_values=request_data.variables,
223223
context_value=context,
224224
root_value=root_value,
@@ -228,7 +228,7 @@ async def execute_operation(
228228

229229
return await execute(
230230
schema=self.schema,
231-
query=request_data.query,
231+
query=request_data.document or request_data.query,
232232
root_value=root_value,
233233
variable_values=request_data.variables,
234234
context_value=context,
@@ -539,10 +539,13 @@ async def parse_multipart_subscriptions(
539539
return self.parse_json(await request.get_body())
540540

541541
async def get_graphql_request_data(
542-
self, data: dict[str, Any], protocol: Literal["http", "multipart-subscription"]
542+
self,
543+
data: dict[str, Any],
544+
protocol: Literal["http", "multipart-subscription", "subscription"],
543545
) -> GraphQLRequestData:
544546
return GraphQLRequestData(
545547
query=data.get("query"),
548+
document=None,
546549
variables=data.get("variables"),
547550
operation_name=data.get("operationName"),
548551
extensions=data.get("extensions"),

src/graphql_server/http/sync_base_view.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def execute_operation(
116116

117117
return execute_sync(
118118
schema=self.schema,
119-
query=request_data.query,
119+
query=request_data.document or request_data.query,
120120
root_value=root_value,
121121
variable_values=request_data.variables,
122122
context_value=context,
@@ -139,6 +139,7 @@ def get_graphql_request_data(
139139
) -> GraphQLRequestData:
140140
return GraphQLRequestData(
141141
query=data.get("query"),
142+
document=None,
142143
variables=data.get("variables"),
143144
operation_name=data.get("operationName"),
144145
extensions=data.get("extensions"),

src/graphql_server/subscriptions/protocols/graphql_transport_ws/handlers.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from graphql_server import execute, subscribe
2020
from graphql_server.exceptions import ConnectionRejectionError, GraphQLValidationError
21+
from graphql_server.http import GraphQLRequestData
2122
from graphql_server.http.exceptions import (
2223
NonJsonMessageReceived,
2324
NonTextMessageReceived,
@@ -253,13 +254,15 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
253254
message["payload"].get("variables"),
254255
)
255256

257+
request_data = await self.view.get_graphql_request_data(
258+
message["payload"], "subscription"
259+
)
260+
256261
operation = Operation(
257262
self,
258263
message["id"],
259264
operation_type,
260-
message["payload"]["query"],
261-
message["payload"].get("variables"),
262-
message["payload"].get("operationName"),
265+
request_data,
263266
)
264267

265268
operation.task = asyncio.create_task(self.run_operation(operation))
@@ -274,7 +277,8 @@ async def run_operation(self, operation: Operation[Context, RootValue]) -> None:
274277
if operation.operation_type == OperationType.SUBSCRIPTION:
275278
result_source = await subscribe(
276279
schema=self.schema,
277-
query=operation.query,
280+
query=operation.request_data.document
281+
or operation.request_data.query,
278282
variable_values=operation.variables,
279283
operation_name=operation.operation_name,
280284
context_value=self.context,
@@ -284,7 +288,8 @@ async def run_operation(self, operation: Operation[Context, RootValue]) -> None:
284288
else:
285289
result_source = await execute(
286290
schema=self.schema,
287-
query=operation.query,
291+
query=operation.request_data.document
292+
or operation.request_data.query,
288293
variable_values=operation.variables,
289294
context_value=self.context,
290295
root_value=self.root_value,
@@ -378,31 +383,37 @@ class Operation(Generic[Context, RootValue]):
378383
"completed",
379384
"handler",
380385
"id",
381-
"operation_name",
382386
"operation_type",
383-
"query",
387+
"request_data",
384388
"task",
385-
"variables",
386389
]
387390

388391
def __init__(
389392
self,
390393
handler: BaseGraphQLTransportWSHandler[Context, RootValue],
391394
id: str,
392395
operation_type: OperationType,
393-
query: str,
394-
variables: Optional[dict[str, object]],
395-
operation_name: Optional[str],
396+
request_data: GraphQLRequestData,
396397
) -> None:
397398
self.handler = handler
398399
self.id = id
399400
self.operation_type = operation_type
400-
self.query = query
401-
self.variables = variables
402-
self.operation_name = operation_name
401+
self.request_data = request_data
403402
self.completed = False
404403
self.task: Optional[asyncio.Task] = None
405404

405+
@property
406+
def query(self) -> Optional[str]:
407+
return self.request_data.query
408+
409+
@property
410+
def variables(self) -> Optional[dict[str, Any]]:
411+
return self.request_data.variables
412+
413+
@property
414+
def operation_name(self) -> Optional[str]:
415+
return self.request_data.operation_name
416+
406417
async def send_operation_message(self, message: Message) -> None:
407418
if self.completed:
408419
return

0 commit comments

Comments
 (0)