diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index 6e9448627..10f7bc3e8 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -1,6 +1,7 @@ import asyncio from collections import defaultdict from collections.abc import Callable +from concurrent.futures import Future from types import MethodType from softioc.asyncio_dispatcher import AsyncioDispatcher @@ -19,8 +20,8 @@ def __init__( self._loop = self._dispatcher.loop self._controller = controller - self._initial_tasks = [controller.connect] - self._scan_tasks: list[asyncio.Task] = [] + self._initial_coros = [controller.connect] + self._scan_futures: set[Future] = set() asyncio.run_coroutine_threadsafe( self._controller.initialise(), self._loop @@ -41,28 +42,29 @@ def _link_process_tasks(self): _link_attribute_sender_class(single_mapping) def __del__(self): - self.stop_scan_tasks() + self.stop_scan_futures() def run(self): - self._run_initial_tasks() - self.start_scan_tasks() + self._run_initial_futures() + self.start_scan_futures() self._run() - def _run_initial_tasks(self): - for task in self._initial_tasks: - future = asyncio.run_coroutine_threadsafe(task(), self._loop) + def _run_initial_futures(self): + for coro in self._initial_coros: + future = asyncio.run_coroutine_threadsafe(coro(), self._loop) future.result() - def start_scan_tasks(self): - self._scan_tasks = [ - self._loop.create_task(coro()) for coro in _get_scan_coros(self._mapping) - ] + def start_scan_futures(self): + self._scan_futures = { + asyncio.run_coroutine_threadsafe(coro(), self._loop) + for coro in _get_scan_coros(self._mapping) + } - def stop_scan_tasks(self): - for task in self._scan_tasks: - if not task.done(): + def stop_scan_futures(self): + for future in self._scan_futures: + if not future.done(): try: - task.cancel() + future.cancel() except asyncio.CancelledError: pass diff --git a/tests/test_backend.py b/tests/test_backend.py index ae518059f..c9f4f8e61 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -10,7 +10,7 @@ def __init__(self, controller): super().__init__(controller) self.init_task_called = False - self._initial_tasks.append(self.init_task) + self._initial_coros.append(self.init_task) async def init_task(self): self.init_task_called = True @@ -44,4 +44,4 @@ async def test_backend(controller): await asyncio.sleep(0.1) assert controller.count > count - backend.stop_scan_tasks() + backend.stop_scan_futures()