Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions src/fastcs/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from types import MethodType

from softioc.asyncio_dispatcher import AsyncioDispatcher
Expand All @@ -21,7 +20,7 @@
self._controller = controller

self._initial_tasks = [controller.connect]
self._scan_tasks: list[Future] = []
self._scan_tasks: list[asyncio.Task] = []

asyncio.run_coroutine_threadsafe(
self._controller.initialise(), self._loop
Expand All @@ -41,22 +40,31 @@
_link_single_controller_put_tasks(single_mapping)
_link_attribute_sender_class(single_mapping)

def __del__(self):
self.stop_scan_tasks()

Check warning on line 44 in src/fastcs/backend.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/backend.py#L44

Added line #L44 was not covered by tests

def run(self):
self._run_initial_tasks()
self._start_scan_tasks()

self.start_scan_tasks()
self._run()

def _run_initial_tasks(self):
for task in self._initial_tasks:
future = asyncio.run_coroutine_threadsafe(task(), self._loop)
future.result()

def _start_scan_tasks(self):
scan_tasks = _get_scan_tasks(self._mapping)
def start_scan_tasks(self):
self._scan_tasks = [
self._loop.create_task(coro()) for coro in _get_scan_coros(self._mapping)
]

for task in scan_tasks:
asyncio.run_coroutine_threadsafe(task(), self._loop)
def stop_scan_tasks(self):
for task in self._scan_tasks:
if not task.done():
try:
task.cancel()
except asyncio.CancelledError:
pass

Check warning on line 67 in src/fastcs/backend.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/backend.py#L66-L67

Added lines #L66 - L67 were not covered by tests

def _run(self):
raise NotImplementedError("Specific Backend must implement _run")
Expand Down Expand Up @@ -98,15 +106,15 @@
return callback


def _get_scan_tasks(mapping: Mapping) -> list[Callable]:
def _get_scan_coros(mapping: Mapping) -> list[Callable]:
scan_dict: dict[float, list[Callable]] = defaultdict(list)

for single_mapping in mapping.get_controller_mappings():
_add_scan_method_tasks(scan_dict, single_mapping)
_add_attribute_updater_tasks(scan_dict, single_mapping)

scan_tasks = _get_periodic_scan_tasks(scan_dict)
return scan_tasks
scan_coros = _get_periodic_scan_coros(scan_dict)
return scan_coros


def _add_scan_method_tasks(
Expand Down Expand Up @@ -144,18 +152,18 @@
return callback


def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_tasks: list[Callable] = []
def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_coros: list[Callable] = []
for period, methods in scan_dict.items():
periodic_scan_tasks.append(_create_periodic_scan_task(period, methods))
periodic_scan_coros.append(_create_periodic_scan_coro(period, methods))

return periodic_scan_tasks
return periodic_scan_coros


def _create_periodic_scan_task(period, methods: list[Callable]) -> Callable:
async def scan_task() -> None:
def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable:
async def scan_coro() -> None:
while True:
await asyncio.gather(*[method() for method in methods])
await asyncio.sleep(period)

return scan_task
return scan_coro
8 changes: 5 additions & 3 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from time import sleep
import asyncio

import pytest

Expand All @@ -16,7 +16,7 @@ async def init_task(self):
self.init_task_called = True

def _run(self):
pass
asyncio.run_coroutine_threadsafe(asyncio.sleep(0.3), self._loop)


@pytest.mark.asyncio
Expand All @@ -41,5 +41,7 @@ async def test_backend(controller):
# Scan tasks should be running
for _ in range(3):
count = controller.count
sleep(0.05)
await asyncio.sleep(0.1)
assert controller.count > count

backend.stop_scan_tasks()
Loading