Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions cylc/uiserver/data_store_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
from .workflows_mgr import workflow_request

if TYPE_CHECKING:
from logging import Logger
from cylc.flow.data_messages_pb2 import PbWorkflow
from cylc.uiserver.workflows_mgr import WorkflowsManager


def log_call(fcn):
Expand Down Expand Up @@ -99,15 +101,26 @@ class DataStoreMgr:
RECONCILE_TIMEOUT = 5. # seconds
PENDING_DELTA_CHECK_INTERVAL = 0.5

def __init__(self, workflows_mgr, log, max_threads=10):
def __init__(
self,
workflows_mgr: 'WorkflowsManager',
log: 'Logger',
max_threads: int = 10
):
self.workflows_mgr = workflows_mgr
self.log = log
self.data = {}
self.data: Dict[str, dict] = {}
self.w_subs: Dict[str, WorkflowSubscriber] = {}
self.topics = {ALL_DELTAS.encode('utf-8'), b'shutdown'}
self.loop = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self.executor = ThreadPoolExecutor(max_threads)
self.delta_queues = {}
self.delta_queues: Dict[str, dict] = {}

@property
def loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop

@log_call
async def register_workflow(self, w_id: str, is_active: bool) -> None:
Expand Down Expand Up @@ -155,8 +168,8 @@ async def connect_workflow(self, w_id, contact_data):
blocking the main loop.

"""
if self.loop is None:
self.loop = asyncio.get_running_loop()
# get event loop:
self.loop

# don't sync if subscription exists
if w_id in self.w_subs:
Expand Down
24 changes: 17 additions & 7 deletions cylc/uiserver/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@
from cylc.uiserver.websockets import authenticated as websockets_authenticated

if TYPE_CHECKING:
from cylc.uiserver.resolvers import Resolvers
from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer
from graphene import Schema
from graphql.execution import ExecutionResult
from jupyter_server.auth.identity import User as JPSUser
from cylc.uiserver.resolvers import Resolvers
from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer


ME = getpass.getuser()
Expand Down Expand Up @@ -300,11 +301,20 @@ class UIServerGraphQLHandler(CylcAppHandler, TornadoGraphQLHandler):
def set_default_headers(self) -> None:
self.set_header('Server', '')

def initialize(self, schema=None, executor=None, middleware=None,
root_value=None, graphiql=False, pretty=False,
batch=False, backend=None, auth=None, **kwargs):
super(TornadoGraphQLHandler, self).initialize()
self.auth = auth
def initialize( # type: ignore[override]
self,
schema: 'Schema',
executor=None,
middleware=None,
root_value=None,
graphiql: bool = False,
pretty: bool = False,
batch: bool = False,
backend=None,
auth=None,
**kwargs
):
CylcAppHandler.initialize(self, auth)
self.schema = schema

if middleware is not None:
Expand Down
113 changes: 78 additions & 35 deletions cylc/uiserver/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
from graphql.language.base import print_ast
import psutil

from cylc.flow.data_store_mgr import WORKFLOW
from cylc.flow.exceptions import (
ClientError,
ClientTimeout,
ServiceFileError,
WorkflowFilesError,
)
from cylc.flow.id import Tokens
from cylc.flow.network.resolvers import BaseResolvers
from cylc.flow.network.schema import GenericResponse
from cylc.flow.scripts.clean import CleanOptions
from cylc.flow.scripts.clean import run

Expand Down Expand Up @@ -205,20 +207,14 @@ class Services:
"""Cylc services provided by the UI Server."""

@staticmethod
def _error(message: Union[Exception, str]):
def _error(message: Union[Exception, str]) -> Tuple[bool, str]:
"""Format error case response."""
return [
False,
str(message)
]
return (False, str(message))

@staticmethod
def _return(message: str):
def _return(message: str) -> Tuple[bool, str]:
"""Format success case response."""
return [
True,
message
]
return (True, message)

@classmethod
async def clean(
Expand All @@ -228,7 +224,7 @@ async def clean(
workflows_mgr: 'WorkflowsManager',
executor: 'Executor',
log: 'Logger'
):
) -> Tuple[bool, str]:
"""Calls `cylc clean`"""
# Convert Schema options → cylc.flow.workflow_files.init_clean opts:
opts = _schema_opts_to_api_opts(args, schema=CleanOptions)
Expand Down Expand Up @@ -275,7 +271,7 @@ async def play(
args: Dict[str, Any],
workflows_mgr: 'WorkflowsManager',
log: 'Logger',
) -> List[Union[bool, str]]:
) -> Tuple[bool, str]:
"""Calls `cylc play`."""
cylc_version = args.pop('cylc_version', None)
results: Dict[str, str] = {}
Expand Down Expand Up @@ -489,44 +485,29 @@ async def mutator(
w_args: Dict[str, Any],
_kwargs: Dict[str, Any],
_meta: Dict[str, Any]
) -> List[Dict[str, Any]]:
) -> List[GenericResponse]:
"""Mutate workflow."""
req_meta = {
'auth_user': info.context.get( # type: ignore[union-attr]
'current_user', 'unknown user'
)
}
w_ids = [
flow[WORKFLOW].id
for flow in await self.get_workflows_data(w_args)]
w_ids = await self.get_workflow_ids(w_args)
if not w_ids:
return [{
'response': (False, 'No matching workflows')}]
return [self._no_matching_workflows_response]
# Pass the request to the workflow GraphQL endpoints
_, variables, _, _ = info.context.get( # type: ignore[union-attr]
'graphql_params'
)
# Create a modified request string,
# containing only the current mutation/field.
operation_ast = deepcopy(info.operation)
operation_ast.selection_set.selections = info.field_asts

graphql_args = {
'request_string': print_ast(operation_ast),
'variables': variables,
}
return await self.workflows_mgr.multi_request( # type: ignore # TODO
graphql_args = _get_graphql_args(info)
results = await self.workflows_mgr.multi_request(
'graphql', w_ids, graphql_args, req_meta=req_meta
)
return process_graphql_multi_request_results(command, results)

async def service(
self,
info: 'ResolveInfo',
command: str,
workflows: Iterable['Tokens'],
kwargs: Dict[str, Any],
) -> List[Union[bool, str]]:

) -> Optional[Tuple[bool, str]]:
if command == 'clean': # noqa: SIM116
return await Services.clean(
workflows,
Expand Down Expand Up @@ -636,3 +617,65 @@ async def stream_log(
file
):
yield item


def _get_graphql_args(info: 'ResolveInfo') -> Dict[str, Any]:
"""Helper function for mutator."""
_, variables, _, _ = info.context.get( # type: ignore[union-attr]
'graphql_params'
)
# Create a modified request string,
# containing only the current mutation/field.
operation_ast = deepcopy(info.operation)
operation_ast.selection_set.selections = info.field_asts

return {
'request_string': print_ast(operation_ast),
'variables': variables,
}


def process_graphql_multi_request_results(
command: str,
results: List[Union[object, bytes, Exception]]
) -> List[GenericResponse]:
"""Wrap multi_request results as list of GenericResponses, suitable for
GraphQL mutator return value.

Extracts result[data][command][results] from Scheduler GraphQL mutator
response to avoid double-nesting when passing it to UIServer GraphQL
mutator.

Args:
command: The GraphQL mutation name.
results: Return value of WorkflowsManager.multi_request().

N.B. WorkflowsManager.multi_request() can be used for non-GraphQL requests,
so this processing is in a separate function.
"""
if not results:
return [
GenericResponse(
success=False, message="No matching workflows running"
)
]
ret: List[GenericResponse] = []
for result in results:
if isinstance(result, (ClientTimeout, ClientError)):
# "Expected" error
ret.append(
GenericResponse(
workflowId=result.workflow,
success=False,
message=str(result)
)
)
continue
if isinstance(result, Exception):
# Unexpected error
raise result
if not isinstance(result, dict) or not result.get('data'):
raise ValueError(f"Unexpected response: {result!r}")
mutation_result: dict = result['data'][command]['results'][0]
ret.append(GenericResponse(**mutation_result))
return ret
9 changes: 5 additions & 4 deletions cylc/uiserver/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from cylc.flow.network.schema import (
NODE_MAP,
CyclePoint,
GenericResponse,
SortArgs,
Task,
Job,
Expand Down Expand Up @@ -77,8 +76,10 @@ async def mutator(
resolvers: 'Resolvers' = (
info.context.get('resolvers') # type: ignore[union-attr]
)
res = await resolvers.service(info, command, parsed_workflows, kwargs)
return GenericResponse(result=res)
res = await resolvers.service(command, parsed_workflows, kwargs)
return info.return_type.graphene_type( # type: ignore[union-attr]
result=res # TODO: results
)


class RunMode(graphene.Enum):
Expand Down Expand Up @@ -108,7 +109,7 @@ class CylcVersion(graphene.String):
"""A Cylc version identifier e.g. 8.0.0"""


class Play(graphene.Mutation):
class Play(graphene.Mutation): # TODO: inherit from cylc.flow.network.schema.WorkflowsMutation?
class Meta:
description = sstrip('''
Start, resume or restart a workflow run.
Expand Down
43 changes: 31 additions & 12 deletions cylc/uiserver/tests/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from textwrap import dedent

import pytest
from tornado.httpclient import HTTPClientError

from cylc.flow.id import Tokens

Expand Down Expand Up @@ -47,12 +48,21 @@ async def _fetch(*endpoint, query=None, headers=None):
**headers,
'Content-Type': 'application/json'
}
return await jp_fetch(
*endpoint,
method='POST',
headers=headers,
body=json.dumps({'query': query}, indent=4),
)
try:
return await jp_fetch(
*endpoint,
method='POST',
headers=headers,
body=json.dumps({'query': query}, indent=4),
)
except HTTPClientError as exc:
# debug info
msg = f"{type(exc).__name__}: {exc}"
if exc.response:
body = json.loads(exc.response.body)
for err in body.get('errors', []):
msg += f"\n\n{err}"
raise Exception(msg)

return _fetch

Expand Down Expand Up @@ -122,13 +132,13 @@ async def _log(*args, **kwargs):
query='''
mutation {
hold(workflows: ["*"], tasks: []) {
result
results { workflowId, success }
}
pause(workflows: ["*"]) {
result
results { workflowId, success }
}
stop(workflows: ["*"]) {
result
results { workflowId, success }
}
}
''',
Expand All @@ -141,23 +151,32 @@ async def _log(*args, **kwargs):
assert calls[0][0][2]['request_string'] == dedent('''
mutation {
hold(workflows: ["*"], tasks: []) {
result
results {
workflowId
success
}
}
}
''').strip()
# the second for the pause mutation
assert calls[1][0][2]['request_string'] == dedent('''
mutation {
pause(workflows: ["*"]) {
result
results {
workflowId
success
}
}
}
''').strip()
# the third for the stop mutation
assert calls[2][0][2]['request_string'] == dedent('''
mutation {
stop(workflows: ["*"]) {
result
results {
workflowId
success
}
}
}
''').strip()
Loading