diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index e1a9e7d2b5..8e57fb5db5 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Transport tasks for calculation jobs.""" +import asyncio import functools import logging import tempfile @@ -406,7 +407,10 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override else: logger.warning(f'killed CalcJob<{node.pk}> but async future was None') raise - except (plumpy.process_states.Interruption, plumpy.futures.CancelledError): + except (plumpy.futures.CancelledError, asyncio.CancelledError): + node.set_process_status(f'Transport task {command} was cancelled') + raise + except plumpy.process_states.Interruption: node.set_process_status(f'Transport task {command} was interrupted') raise else: diff --git a/aiida/engine/transports.py b/aiida/engine/transports.py index b722140834..d301235e27 100644 --- a/aiida/engine/transports.py +++ b/aiida/engine/transports.py @@ -108,6 +108,11 @@ def do_open(): try: transport_request.count += 1 yield transport_request.future + except asyncio.CancelledError: # pylint: disable=try-except-raise + # note this is only required in python<=3.7, + # where asyncio.CancelledError inherits from Exception + _LOGGER.debug('Transport task cancelled') + raise except Exception: _LOGGER.error('Exception whilst using transport:\n%s', traceback.format_exc()) raise diff --git a/aiida/manage/external/rmq.py b/aiida/manage/external/rmq.py index f2069603ab..c7cccfd149 100644 --- a/aiida/manage/external/rmq.py +++ b/aiida/manage/external/rmq.py @@ -9,6 +9,7 @@ ########################################################################### # pylint: disable=cyclic-import """Components to communicate tasks to RabbitMQ.""" +import asyncio from collections.abc import Mapping import logging import traceback @@ -209,6 +210,10 @@ async def _continue(self, communicator, pid, nowait, tag=None): message = 'the class of the process could not be imported.' self.handle_continue_exception(node, exception, message) raise + except asyncio.CancelledError: # pylint: disable=try-except-raise + # note this is only required in python<=3.7, + # where asyncio.CancelledError inherits from Exception + raise except Exception as exception: message = 'failed to recreate the process instance in order to continue it.' self.handle_continue_exception(node, exception, message) diff --git a/environment.yml b/environment.yml index 09b03959f8..1f5d72c388 100644 --- a/environment.yml +++ b/environment.yml @@ -21,11 +21,11 @@ dependencies: - ipython~=7.20 - jinja2~=2.10 - jsonschema~=3.0 -- kiwipy[rmq]~=0.7.3 +- kiwipy[rmq]~=0.7.4 - numpy~=1.17 - pamqp~=2.3 - paramiko>=2.7.2,~=2.7 -- plumpy~=0.18.6 +- plumpy~=0.19.0 - pgsu~=0.1.0 - psutil~=5.6 - psycopg2>=2.8.3,~=2.8 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 49c846ca68..b1decfa255 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -57,7 +57,7 @@ jupyter-console==6.2.0 jupyter-core==4.7.1 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.0 -kiwipy==0.7.3 +kiwipy==0.7.4 kiwisolver==1.3.1 Mako==1.1.4 MarkupSafe==1.1.1 @@ -88,7 +88,7 @@ pickleshare==0.7.5 Pillow==8.1.0 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.18.6 +plumpy==0.19.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 8ab1f0cd09..4d2326794d 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -56,7 +56,7 @@ jupyter-console==6.2.0 jupyter-core==4.7.1 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.0 -kiwipy==0.7.3 +kiwipy==0.7.4 kiwisolver==1.3.1 Mako==1.1.4 MarkupSafe==1.1.1 @@ -87,7 +87,7 @@ pickleshare==0.7.5 Pillow==8.1.0 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.18.6 +plumpy==0.19.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index f050094ca4..5bcba80782 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -56,7 +56,7 @@ jupyter-console==6.2.0 jupyter-core==4.7.1 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.0 -kiwipy==0.7.3 +kiwipy==0.7.4 kiwisolver==1.3.1 Mako==1.1.4 MarkupSafe==1.1.1 @@ -87,7 +87,7 @@ pickleshare==0.7.5 Pillow==8.1.0 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.18.6 +plumpy==0.19.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 diff --git a/setup.json b/setup.json index 9dac042c00..1c62482f47 100644 --- a/setup.json +++ b/setup.json @@ -35,11 +35,11 @@ "ipython~=7.20", "jinja2~=2.10", "jsonschema~=3.0", - "kiwipy[rmq]~=0.7.3", + "kiwipy[rmq]~=0.7.4", "numpy~=1.17", "pamqp~=2.3", "paramiko~=2.7,>=2.7.2", - "plumpy~=0.18.6", + "plumpy~=0.19.0", "pgsu~=0.1.0", "psutil~=5.6", "psycopg2-binary~=2.8,>=2.8.3", diff --git a/tests/engine/test_daemon.py b/tests/engine/test_daemon.py index fd9c64ff7b..53f6b4a20b 100644 --- a/tests/engine/test_daemon.py +++ b/tests/engine/test_daemon.py @@ -8,8 +8,35 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test daemon module.""" -from aiida.backends.testbase import AiidaTestCase +import asyncio +from plumpy.process_states import ProcessState +import pytest -class TestDaemon(AiidaTestCase): - """Testing the daemon.""" +from aiida.manage.manager import get_manager +from tests.utils import processes as test_processes + + +async def reach_waiting_state(process): + while process.state != ProcessState.WAITING: + await asyncio.sleep(0.1) + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_cancel_process_task(): + """This test is designed to replicate how processes are cancelled in the current `shutdown_runner` callback. + + The `CancelledError` should bubble up to the caller, and not be caught and transition the process to excepted. + """ + runner = get_manager().get_runner() + # create the process and start it running + process = runner.instantiate_process(test_processes.WaitProcess) + task = runner.loop.create_task(process.step_until_terminated()) + # wait for the process to reach a WAITING state + runner.loop.run_until_complete(asyncio.wait_for(reach_waiting_state(process), 5.0)) + # cancel the task and wait for the cancellation + task.cancel() + with pytest.raises(asyncio.CancelledError): + runner.loop.run_until_complete(asyncio.wait_for(task, 5.0)) + # the node should still record a waiting state, not excepted + assert process.node.process_state == ProcessState.WAITING