diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index 80c8eab67..6e9448627 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -1,7 +1,6 @@ import asyncio from collections import defaultdict from collections.abc import Callable -from concurrent.futures import Future from types import MethodType from softioc.asyncio_dispatcher import AsyncioDispatcher @@ -21,7 +20,7 @@ def __init__( self._controller = controller self._initial_tasks = [controller.connect] - self._scan_tasks: list[Future] = [] + self._scan_tasks: list[asyncio.Task] = [] asyncio.run_coroutine_threadsafe( self._controller.initialise(), self._loop @@ -41,10 +40,12 @@ def _link_process_tasks(self): _link_single_controller_put_tasks(single_mapping) _link_attribute_sender_class(single_mapping) + def __del__(self): + self.stop_scan_tasks() + def run(self): self._run_initial_tasks() - self._start_scan_tasks() - + self.start_scan_tasks() self._run() def _run_initial_tasks(self): @@ -52,11 +53,18 @@ def _run_initial_tasks(self): future = asyncio.run_coroutine_threadsafe(task(), self._loop) future.result() - def _start_scan_tasks(self): - scan_tasks = _get_scan_tasks(self._mapping) + def start_scan_tasks(self): + self._scan_tasks = [ + self._loop.create_task(coro()) for coro in _get_scan_coros(self._mapping) + ] - for task in scan_tasks: - asyncio.run_coroutine_threadsafe(task(), self._loop) + def stop_scan_tasks(self): + for task in self._scan_tasks: + if not task.done(): + try: + task.cancel() + except asyncio.CancelledError: + pass def _run(self): raise NotImplementedError("Specific Backend must implement _run") @@ -98,15 +106,15 @@ async def callback(value): return callback -def _get_scan_tasks(mapping: Mapping) -> list[Callable]: +def _get_scan_coros(mapping: Mapping) -> list[Callable]: scan_dict: dict[float, list[Callable]] = defaultdict(list) for single_mapping in mapping.get_controller_mappings(): _add_scan_method_tasks(scan_dict, single_mapping) _add_attribute_updater_tasks(scan_dict, single_mapping) - scan_tasks = _get_periodic_scan_tasks(scan_dict) - return scan_tasks + scan_coros = _get_periodic_scan_coros(scan_dict) + return scan_coros def _add_scan_method_tasks( @@ -144,18 +152,18 @@ async def callback(): return callback -def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callable]]) -> list[Callable]: - periodic_scan_tasks: list[Callable] = [] +def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: + periodic_scan_coros: list[Callable] = [] for period, methods in scan_dict.items(): - periodic_scan_tasks.append(_create_periodic_scan_task(period, methods)) + periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) - return periodic_scan_tasks + return periodic_scan_coros -def _create_periodic_scan_task(period, methods: list[Callable]) -> Callable: - async def scan_task() -> None: +def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: + async def scan_coro() -> None: while True: await asyncio.gather(*[method() for method in methods]) await asyncio.sleep(period) - return scan_task + return scan_coro diff --git a/tests/test_backend.py b/tests/test_backend.py index 67ea12fe9..ae518059f 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,4 +1,4 @@ -from time import sleep +import asyncio import pytest @@ -16,7 +16,7 @@ async def init_task(self): self.init_task_called = True def _run(self): - pass + asyncio.run_coroutine_threadsafe(asyncio.sleep(0.3), self._loop) @pytest.mark.asyncio @@ -41,5 +41,7 @@ async def test_backend(controller): # Scan tasks should be running for _ in range(3): count = controller.count - sleep(0.05) + await asyncio.sleep(0.1) assert controller.count > count + + backend.stop_scan_tasks()