Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions src/fastcs/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from types import MethodType

from softioc.asyncio_dispatcher import AsyncioDispatcher
Expand All @@ -21,7 +20,7 @@ def __init__(
self._controller = controller

self._initial_tasks = [controller.connect]
self._scan_tasks: list[Future] = []
self._scan_tasks: list[asyncio.Task] = []

asyncio.run_coroutine_threadsafe(
self._controller.initialise(), self._loop
Expand All @@ -41,22 +40,31 @@ def _link_process_tasks(self):
_link_single_controller_put_tasks(single_mapping)
_link_attribute_sender_class(single_mapping)

def __del__(self):
self.stop_scan_tasks()

def run(self):
self._run_initial_tasks()
self._start_scan_tasks()

self.start_scan_tasks()
self._run()

def _run_initial_tasks(self):
for task in self._initial_tasks:
future = asyncio.run_coroutine_threadsafe(task(), self._loop)
future.result()

def _start_scan_tasks(self):
scan_tasks = _get_scan_tasks(self._mapping)
def start_scan_tasks(self):
self._scan_tasks = [
self._loop.create_task(coro()) for coro in _get_scan_coros(self._mapping)
]

for task in scan_tasks:
asyncio.run_coroutine_threadsafe(task(), self._loop)
def stop_scan_tasks(self):
for task in self._scan_tasks:
if not task.done():
try:
task.cancel()
except asyncio.CancelledError:
pass

def _run(self):
raise NotImplementedError("Specific Backend must implement _run")
Expand Down Expand Up @@ -98,15 +106,15 @@ async def callback(value):
return callback


def _get_scan_tasks(mapping: Mapping) -> list[Callable]:
def _get_scan_coros(mapping: Mapping) -> list[Callable]:
scan_dict: dict[float, list[Callable]] = defaultdict(list)

for single_mapping in mapping.get_controller_mappings():
_add_scan_method_tasks(scan_dict, single_mapping)
_add_attribute_updater_tasks(scan_dict, single_mapping)

scan_tasks = _get_periodic_scan_tasks(scan_dict)
return scan_tasks
scan_coros = _get_periodic_scan_coros(scan_dict)
return scan_coros


def _add_scan_method_tasks(
Expand Down Expand Up @@ -144,18 +152,18 @@ async def callback():
return callback


def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_tasks: list[Callable] = []
def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_coros: list[Callable] = []
for period, methods in scan_dict.items():
periodic_scan_tasks.append(_create_periodic_scan_task(period, methods))
periodic_scan_coros.append(_create_periodic_scan_coro(period, methods))

return periodic_scan_tasks
return periodic_scan_coros


def _create_periodic_scan_task(period, methods: list[Callable]) -> Callable:
async def scan_task() -> None:
def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable:
async def scan_coro() -> None:
while True:
await asyncio.gather(*[method() for method in methods])
await asyncio.sleep(period)

return scan_task
return scan_coro
8 changes: 5 additions & 3 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from time import sleep
import asyncio

import pytest

Expand All @@ -16,7 +16,7 @@ async def init_task(self):
self.init_task_called = True

def _run(self):
pass
asyncio.run_coroutine_threadsafe(asyncio.sleep(0.3), self._loop)


@pytest.mark.asyncio
Expand All @@ -41,5 +41,7 @@ async def test_backend(controller):
# Scan tasks should be running
for _ in range(3):
count = controller.count
sleep(0.05)
await asyncio.sleep(0.1)
assert controller.count > count

backend.stop_scan_tasks()