diff --git a/cylc/uiserver/data_store_mgr.py b/cylc/uiserver/data_store_mgr.py index 870de40a..99ff1715 100644 --- a/cylc/uiserver/data_store_mgr.py +++ b/cylc/uiserver/data_store_mgr.py @@ -57,7 +57,9 @@ from .workflows_mgr import workflow_request if TYPE_CHECKING: + from logging import Logger from cylc.flow.data_messages_pb2 import PbWorkflow + from cylc.uiserver.workflows_mgr import WorkflowsManager def log_call(fcn): @@ -99,15 +101,26 @@ class DataStoreMgr: RECONCILE_TIMEOUT = 5. # seconds PENDING_DELTA_CHECK_INTERVAL = 0.5 - def __init__(self, workflows_mgr, log, max_threads=10): + def __init__( + self, + workflows_mgr: 'WorkflowsManager', + log: 'Logger', + max_threads: int = 10 + ): self.workflows_mgr = workflows_mgr self.log = log - self.data = {} + self.data: Dict[str, dict] = {} self.w_subs: Dict[str, WorkflowSubscriber] = {} self.topics = {ALL_DELTAS.encode('utf-8'), b'shutdown'} - self.loop = None + self._loop: Optional[asyncio.AbstractEventLoop] = None self.executor = ThreadPoolExecutor(max_threads) - self.delta_queues = {} + self.delta_queues: Dict[str, dict] = {} + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + self._loop = asyncio.get_running_loop() + return self._loop @log_call async def register_workflow(self, w_id: str, is_active: bool) -> None: @@ -155,8 +168,8 @@ async def connect_workflow(self, w_id, contact_data): blocking the main loop. """ - if self.loop is None: - self.loop = asyncio.get_running_loop() + # get event loop: + self.loop # don't sync if subscription exists if w_id in self.w_subs: diff --git a/cylc/uiserver/handlers.py b/cylc/uiserver/handlers.py index c4e50fcc..de6df64a 100644 --- a/cylc/uiserver/handlers.py +++ b/cylc/uiserver/handlers.py @@ -37,10 +37,11 @@ from cylc.uiserver.websockets import authenticated as websockets_authenticated if TYPE_CHECKING: - from cylc.uiserver.resolvers import Resolvers - from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer + from graphene import Schema from graphql.execution import ExecutionResult from jupyter_server.auth.identity import User as JPSUser + from cylc.uiserver.resolvers import Resolvers + from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer ME = getpass.getuser() @@ -300,11 +301,20 @@ class UIServerGraphQLHandler(CylcAppHandler, TornadoGraphQLHandler): def set_default_headers(self) -> None: self.set_header('Server', '') - def initialize(self, schema=None, executor=None, middleware=None, - root_value=None, graphiql=False, pretty=False, - batch=False, backend=None, auth=None, **kwargs): - super(TornadoGraphQLHandler, self).initialize() - self.auth = auth + def initialize( # type: ignore[override] + self, + schema: 'Schema', + executor=None, + middleware=None, + root_value=None, + graphiql: bool = False, + pretty: bool = False, + batch: bool = False, + backend=None, + auth=None, + **kwargs + ): + CylcAppHandler.initialize(self, auth) self.schema = schema if middleware is not None: diff --git a/cylc/uiserver/resolvers.py b/cylc/uiserver/resolvers.py index db3e357f..5abc9c44 100644 --- a/cylc/uiserver/resolvers.py +++ b/cylc/uiserver/resolvers.py @@ -39,13 +39,15 @@ from graphql.language.base import print_ast import psutil -from cylc.flow.data_store_mgr import WORKFLOW from cylc.flow.exceptions import ( + ClientError, + ClientTimeout, ServiceFileError, WorkflowFilesError, ) from cylc.flow.id import Tokens from cylc.flow.network.resolvers import BaseResolvers +from cylc.flow.network.schema import GenericResponse from cylc.flow.scripts.clean import CleanOptions from cylc.flow.scripts.clean import run @@ -205,20 +207,14 @@ class Services: """Cylc services provided by the UI Server.""" @staticmethod - def _error(message: Union[Exception, str]): + def _error(message: Union[Exception, str]) -> Tuple[bool, str]: """Format error case response.""" - return [ - False, - str(message) - ] + return (False, str(message)) @staticmethod - def _return(message: str): + def _return(message: str) -> Tuple[bool, str]: """Format success case response.""" - return [ - True, - message - ] + return (True, message) @classmethod async def clean( @@ -228,7 +224,7 @@ async def clean( workflows_mgr: 'WorkflowsManager', executor: 'Executor', log: 'Logger' - ): + ) -> Tuple[bool, str]: """Calls `cylc clean`""" # Convert Schema options → cylc.flow.workflow_files.init_clean opts: opts = _schema_opts_to_api_opts(args, schema=CleanOptions) @@ -275,7 +271,7 @@ async def play( args: Dict[str, Any], workflows_mgr: 'WorkflowsManager', log: 'Logger', - ) -> List[Union[bool, str]]: + ) -> Tuple[bool, str]: """Calls `cylc play`.""" cylc_version = args.pop('cylc_version', None) results: Dict[str, str] = {} @@ -489,44 +485,29 @@ async def mutator( w_args: Dict[str, Any], _kwargs: Dict[str, Any], _meta: Dict[str, Any] - ) -> List[Dict[str, Any]]: + ) -> List[GenericResponse]: """Mutate workflow.""" req_meta = { 'auth_user': info.context.get( # type: ignore[union-attr] 'current_user', 'unknown user' ) } - w_ids = [ - flow[WORKFLOW].id - for flow in await self.get_workflows_data(w_args)] + w_ids = await self.get_workflow_ids(w_args) if not w_ids: - return [{ - 'response': (False, 'No matching workflows')}] + return [self._no_matching_workflows_response] # Pass the request to the workflow GraphQL endpoints - _, variables, _, _ = info.context.get( # type: ignore[union-attr] - 'graphql_params' - ) - # Create a modified request string, - # containing only the current mutation/field. - operation_ast = deepcopy(info.operation) - operation_ast.selection_set.selections = info.field_asts - - graphql_args = { - 'request_string': print_ast(operation_ast), - 'variables': variables, - } - return await self.workflows_mgr.multi_request( # type: ignore # TODO + graphql_args = _get_graphql_args(info) + results = await self.workflows_mgr.multi_request( 'graphql', w_ids, graphql_args, req_meta=req_meta ) + return process_graphql_multi_request_results(command, results) async def service( self, - info: 'ResolveInfo', command: str, workflows: Iterable['Tokens'], kwargs: Dict[str, Any], - ) -> List[Union[bool, str]]: - + ) -> Optional[Tuple[bool, str]]: if command == 'clean': # noqa: SIM116 return await Services.clean( workflows, @@ -636,3 +617,65 @@ async def stream_log( file ): yield item + + +def _get_graphql_args(info: 'ResolveInfo') -> Dict[str, Any]: + """Helper function for mutator.""" + _, variables, _, _ = info.context.get( # type: ignore[union-attr] + 'graphql_params' + ) + # Create a modified request string, + # containing only the current mutation/field. + operation_ast = deepcopy(info.operation) + operation_ast.selection_set.selections = info.field_asts + + return { + 'request_string': print_ast(operation_ast), + 'variables': variables, + } + + +def process_graphql_multi_request_results( + command: str, + results: List[Union[object, bytes, Exception]] +) -> List[GenericResponse]: + """Wrap multi_request results as list of GenericResponses, suitable for + GraphQL mutator return value. + + Extracts result[data][command][results] from Scheduler GraphQL mutator + response to avoid double-nesting when passing it to UIServer GraphQL + mutator. + + Args: + command: The GraphQL mutation name. + results: Return value of WorkflowsManager.multi_request(). + + N.B. WorkflowsManager.multi_request() can be used for non-GraphQL requests, + so this processing is in a separate function. + """ + if not results: + return [ + GenericResponse( + success=False, message="No matching workflows running" + ) + ] + ret: List[GenericResponse] = [] + for result in results: + if isinstance(result, (ClientTimeout, ClientError)): + # "Expected" error + ret.append( + GenericResponse( + workflowId=result.workflow, + success=False, + message=str(result) + ) + ) + continue + if isinstance(result, Exception): + # Unexpected error + raise result + if not isinstance(result, dict) or not result.get('data'): + raise ValueError(f"Unexpected response: {result!r}") + mutation_result: dict = result['data'][command]['results'][0] + ret.append(GenericResponse(**mutation_result)) + return ret diff --git a/cylc/uiserver/schema.py b/cylc/uiserver/schema.py index 99933a74..b7bca30a 100644 --- a/cylc/uiserver/schema.py +++ b/cylc/uiserver/schema.py @@ -33,7 +33,6 @@ from cylc.flow.network.schema import ( NODE_MAP, CyclePoint, - GenericResponse, SortArgs, Task, Job, @@ -77,8 +76,10 @@ async def mutator( resolvers: 'Resolvers' = ( info.context.get('resolvers') # type: ignore[union-attr] ) - res = await resolvers.service(info, command, parsed_workflows, kwargs) - return GenericResponse(result=res) + res = await resolvers.service(command, parsed_workflows, kwargs) + return info.return_type.graphene_type( # type: ignore[union-attr] + result=res # TODO: results + ) class RunMode(graphene.Enum): @@ -108,7 +109,7 @@ class CylcVersion(graphene.String): """A Cylc version identifier e.g. 8.0.0""" -class Play(graphene.Mutation): +class Play(graphene.Mutation): # TODO: inherit from cylc.flow.network.schema.WorkflowsMutation? class Meta: description = sstrip(''' Start, resume or restart a workflow run. diff --git a/cylc/uiserver/tests/test_graphql.py b/cylc/uiserver/tests/test_graphql.py index d0b26d0b..256cb7a1 100644 --- a/cylc/uiserver/tests/test_graphql.py +++ b/cylc/uiserver/tests/test_graphql.py @@ -17,6 +17,7 @@ from textwrap import dedent import pytest +from tornado.httpclient import HTTPClientError from cylc.flow.id import Tokens @@ -47,12 +48,21 @@ async def _fetch(*endpoint, query=None, headers=None): **headers, 'Content-Type': 'application/json' } - return await jp_fetch( - *endpoint, - method='POST', - headers=headers, - body=json.dumps({'query': query}, indent=4), - ) + try: + return await jp_fetch( + *endpoint, + method='POST', + headers=headers, + body=json.dumps({'query': query}, indent=4), + ) + except HTTPClientError as exc: + # debug info + msg = f"{type(exc).__name__}: {exc}" + if exc.response: + body = json.loads(exc.response.body) + for err in body.get('errors', []): + msg += f"\n\n{err}" + raise Exception(msg) return _fetch @@ -122,13 +132,13 @@ async def _log(*args, **kwargs): query=''' mutation { hold(workflows: ["*"], tasks: []) { - result + results { workflowId, success } } pause(workflows: ["*"]) { - result + results { workflowId, success } } stop(workflows: ["*"]) { - result + results { workflowId, success } } } ''', @@ -141,7 +151,10 @@ async def _log(*args, **kwargs): assert calls[0][0][2]['request_string'] == dedent(''' mutation { hold(workflows: ["*"], tasks: []) { - result + results { + workflowId + success + } } } ''').strip() @@ -149,7 +162,10 @@ async def _log(*args, **kwargs): assert calls[1][0][2]['request_string'] == dedent(''' mutation { pause(workflows: ["*"]) { - result + results { + workflowId + success + } } } ''').strip() @@ -157,7 +173,10 @@ async def _log(*args, **kwargs): assert calls[2][0][2]['request_string'] == dedent(''' mutation { stop(workflows: ["*"]) { - result + results { + workflowId + success + } } } ''').strip() diff --git a/cylc/uiserver/tests/test_resolvers.py b/cylc/uiserver/tests/test_resolvers.py index 0ed28726..7ea7d1ee 100644 --- a/cylc/uiserver/tests/test_resolvers.py +++ b/cylc/uiserver/tests/test_resolvers.py @@ -41,13 +41,12 @@ def test__schema_opts_to_api_opts(schema_opts, schema, expect): @pytest.mark.parametrize( 'func, message, expect', [ - (services._return, 'Hello.', [True, 'Hello.']), - (services._error, 'Goodbye.', [False, 'Goodbye.']) + (services._return, 'Hello.', (True, 'Hello.')), + (services._error, 'Goodbye.', (False, 'Goodbye.')) ] ) def test_Services_anciliary_methods(func, message, expect): - """Small functions return [bool, message]. - """ + """Small functions return (bool, message).""" assert func(message) == expect @@ -58,7 +57,7 @@ def test_Services_anciliary_methods(func, message, expect): [Tokens('wflow1'), Tokens('~murray/wflow2')], {}, {}, - [True, "Workflow(s) started"], + (True, "Workflow(s) started"), {}, id="multiple" ), @@ -66,7 +65,7 @@ def test_Services_anciliary_methods(func, message, expect): [Tokens('~feynman/wflow1')], {}, {}, - [False, "Cannot start workflows for other users."], + (False, "Cannot start workflows for other users."), {}, id="other user's wflow" ), @@ -74,7 +73,7 @@ def test_Services_anciliary_methods(func, message, expect): [Tokens('wflow1')], {'cylc_version': 'top'}, {'CYLC_VERSION': 'bottom', 'CYLC_ENV_NAME': 'quark'}, - [True, "Workflow(s) started"], + (True, "Workflow(s) started"), {'CYLC_VERSION': 'top'}, id="cylc version overrides env" ), @@ -82,7 +81,7 @@ def test_Services_anciliary_methods(func, message, expect): [Tokens('wflow1')], {}, {'CYLC_VERSION': 'charm', 'CYLC_ENV_NAME': 'quark'}, - [True, "Workflow(s) started"], + (True, "Workflow(s) started"), {'CYLC_VERSION': 'charm', 'CYLC_ENV_NAME': 'quark'}, id="cylc env not overriden if no version specified" ), @@ -93,7 +92,7 @@ async def test_play( workflows: List[Tokens], args: Dict[str, Any], env: Dict[str, str], - expected_ret: list, + expected_ret: Tuple[bool, str], expected_env: Dict[str, str], ): """It runs cylc play correctly. @@ -223,9 +222,9 @@ def wait(timeout): workflows_mgr=Mock(spec=WorkflowsManager), log=Mock(), ) - assert ret == [ + assert ret == ( False, "Command 'cylc play wflow1' timed out after 20 seconds" - ] + ) async def test_cat_log(workflow_run_dir): diff --git a/cylc/uiserver/workflows_mgr.py b/cylc/uiserver/workflows_mgr.py index e441c388..21f5afba 100644 --- a/cylc/uiserver/workflows_mgr.py +++ b/cylc/uiserver/workflows_mgr.py @@ -95,6 +95,7 @@ async def workflow_request( else: print(msg, file=sys.stderr) print(exc, file=sys.stderr) + exc.workflow = client.workflow raise exc