Skip to content

Commit 1a7f797

Browse files
authored
GraphQL: Improve schema error handling & tidy (#6026)
1 parent 095fa7e commit 1a7f797

File tree

3 files changed

+115
-100
lines changed

3 files changed

+115
-100
lines changed

cylc/flow/network/graphql.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
"""
2121

2222
from functools import partial
23+
from inspect import isclass, iscoroutinefunction
2324
import logging
2425
from typing import TYPE_CHECKING, Any, Tuple, Union
2526

26-
from inspect import isclass, iscoroutinefunction
27-
2827
from graphene.utils.str_converters import to_snake_case
2928
from graphql.execution.utils import (
3029
get_operation_root_type, get_field_def
@@ -35,16 +34,16 @@
3534
from graphql.backend.base import GraphQLBackend, GraphQLDocument
3635
from graphql.backend.core import execute_and_validate
3736
from graphql.utils.base import type_from_ast
38-
from graphql.type import get_named_type
37+
from graphql.type.definition import get_named_type
3938
from promise import Promise
4039
from rx import Observable
4140

42-
from cylc.flow.network.schema import NODE_MAP, get_type_str
41+
from cylc.flow.network.schema import NODE_MAP
4342

4443
if TYPE_CHECKING:
4544
from graphql.execution import ExecutionResult
4645
from graphql.language.ast import Document
47-
from graphql.type import GraphQLSchema
46+
from graphql.type.schema import GraphQLSchema
4847

4948

5049
logger = logging.getLogger(__name__)
@@ -376,18 +375,18 @@ def resolve(self, next_, root, info, **args):
376375

377376
# Avoid using the protobuf default if field isn't set.
378377
if (
379-
hasattr(root, 'ListFields')
380-
and hasattr(root, field_name)
381-
and get_type_str(info.return_type) not in NODE_MAP
378+
hasattr(root, 'ListFields')
379+
and hasattr(root, field_name)
380+
and get_named_type(info.return_type).name not in NODE_MAP
382381
):
383382

384383
# Gather fields set in root
385384
parent_path_string = f'{info.path[:-1:]}'
386385
stamp = getattr(root, 'stamp', '')
387386
if (
388-
parent_path_string not in self.field_sets
389-
or self.field_sets[
390-
parent_path_string]['stamp'] != stamp
387+
parent_path_string not in self.field_sets
388+
or self.field_sets[
389+
parent_path_string]['stamp'] != stamp
391390
):
392391
self.field_sets[parent_path_string] = {
393392
'stamp': stamp,
@@ -398,36 +397,33 @@ def resolve(self, next_, root, info, **args):
398397
}
399398

400399
if (
401-
parent_path_string in self.field_sets
402-
and field_name not in self.field_sets[
403-
parent_path_string]['fields']
400+
parent_path_string in self.field_sets
401+
and field_name not in self.field_sets[
402+
parent_path_string]['fields']
404403
):
405404
return None
406405
# Do not resolve subfields of an empty type
407406
# by setting as null in parent/root.
408-
elif (
409-
isinstance(root, dict)
410-
and field_name in root
411-
):
407+
elif isinstance(root, dict) and field_name in root:
412408
field_value = root[field_name]
413409
if (
414-
field_value in EMPTY_VALUES
415-
or (
416-
hasattr(field_value, 'ListFields')
417-
and not field_value.ListFields()
418-
)
410+
field_value in EMPTY_VALUES
411+
or (
412+
hasattr(field_value, 'ListFields')
413+
and not field_value.ListFields()
414+
)
419415
):
420416
return None
421417
if (
422-
info.operation.operation in self.ASYNC_OPS
423-
or iscoroutinefunction(next_)
418+
info.operation.operation in self.ASYNC_OPS
419+
or iscoroutinefunction(next_)
424420
):
425421
return self.async_null_setter(next_, root, info, **args)
426422
return null_setter(next_(root, info, **args))
427423

428424
if (
429-
info.operation.operation in self.ASYNC_OPS
430-
or iscoroutinefunction(next_)
425+
info.operation.operation in self.ASYNC_OPS
426+
or iscoroutinefunction(next_)
431427
):
432428
return self.async_resolve(next_, root, info, **args)
433429
return next_(root, info, **args)

cylc/flow/network/resolvers.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
from time import time
2626
from typing import (
2727
Any,
28+
AsyncGenerator,
2829
Dict,
2930
List,
3031
NamedTuple,
3132
Optional,
3233
Tuple,
3334
TYPE_CHECKING,
3435
Union,
36+
cast,
3537
)
3638
from uuid import uuid4
3739

@@ -58,6 +60,8 @@
5860
from cylc.flow.data_store_mgr import DataStoreMgr
5961
from cylc.flow.scheduler import Scheduler
6062

63+
DeltaQueue = queue.Queue[Tuple[str, str, dict]]
64+
6165

6266
class TaskMsg(NamedTuple):
6367
"""Tuple for Scheduler.message_queue"""
@@ -395,7 +399,7 @@ async def get_nodes_all(self, node_type, args):
395399
[
396400
node
397401
for flow in await self.get_workflows_data(args)
398-
for node in flow.get(node_type).values()
402+
for node in flow[node_type].values()
399403
if node_filter(
400404
node,
401405
node_type,
@@ -538,7 +542,9 @@ async def get_nodes_edges(self, root_nodes, args):
538542
nodes=sort_elements(nodes, args),
539543
edges=sort_elements(edges, args))
540544

541-
async def subscribe_delta(self, root, info, args):
545+
async def subscribe_delta(
546+
self, root, info: 'ResolveInfo', args
547+
) -> AsyncGenerator[Any, None]:
542548
"""Delta subscription async generator.
543549
544550
Async generator mapping the incoming protobuf deltas to
@@ -553,19 +559,19 @@ async def subscribe_delta(self, root, info, args):
553559
self.delta_store[sub_id] = {}
554560

555561
op_id = root
556-
if 'ops_queue' not in info.context:
557-
info.context['ops_queue'] = {}
558-
info.context['ops_queue'][op_id] = queue.Queue()
559-
op_queue = info.context['ops_queue'][op_id]
562+
op_queue: queue.Queue[Tuple[UUID, str]] = queue.Queue()
563+
cast('dict', info.context).setdefault(
564+
'ops_queue', {}
565+
)[op_id] = op_queue
560566
self.delta_processing_flows[sub_id] = set()
561567
delta_processing_flows = self.delta_processing_flows[sub_id]
562568

563569
delta_queues = self.data_store_mgr.delta_queues
564-
deltas_queue = queue.Queue()
570+
deltas_queue: DeltaQueue = queue.Queue()
565571

566-
counters = {}
567-
delta_yield_queue = queue.Queue()
568-
flow_delta_queues = {}
572+
counters: Dict[str, int] = {}
573+
delta_yield_queue: DeltaQueue = queue.Queue()
574+
flow_delta_queues: Dict[str, queue.Queue[Tuple[str, dict]]] = {}
569575
try:
570576
# Iterate over the queue yielding deltas
571577
w_ids = workflow_ids

0 commit comments

Comments
 (0)