diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index bacd3f867..dce98bb81 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -14,7 +14,7 @@ class Attribute(Generic[T, AttributeIORefT]): """Base FastCS attribute. - Instances of this class added to a ``Controller`` will be used by the backend. + Instances of this class added to a ``Controller`` will be used by the FastCS class. """ def __init__( diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py deleted file mode 100644 index 788510f99..000000000 --- a/src/fastcs/backend.py +++ /dev/null @@ -1,198 +0,0 @@ -import asyncio -from collections import defaultdict -from collections.abc import Callable, Coroutine - -from fastcs.attribute_io_ref import AttributeIORef -from fastcs.cs_methods import Command, Put, Scan -from fastcs.datatypes import T - -from .attributes import ONCE, AttrR, AttrW -from .controller import BaseController, Controller -from .controller_api import ControllerAPI -from .exceptions import FastCSError -from .util import validate_hinted_attributes - - -class Backend: - """For keeping track of tasks during FastCS serving.""" - - def __init__( - self, - controller: Controller, - loop: asyncio.AbstractEventLoop, - ): - self._controller = controller - self._loop = loop - - self._initial_coros = [controller.connect] - self._scan_tasks: set[asyncio.Task] = set() - - # Initialise controller and then build its APIs - loop.run_until_complete(controller.initialise()) - loop.run_until_complete(controller.attribute_initialise()) - validate_hinted_attributes(controller) - self.controller_api = build_controller_api(controller) - self._link_process_tasks() - - def _link_process_tasks(self): - for controller_api in self.controller_api.walk_api(): - _link_put_tasks(controller_api) - - def __del__(self): - self._stop_scan_tasks() - - async def serve(self): - scans, initials = _get_scan_and_initial_coros(self.controller_api) - self._initial_coros += initials - await self._run_initial_coros() - await self._start_scan_tasks(scans) - - async def _run_initial_coros(self): - for coro in self._initial_coros: - await coro() - - async def _start_scan_tasks( - self, coros: list[Callable[[], Coroutine[None, None, None]]] - ): - self._scan_tasks = {self._loop.create_task(coro()) for coro in coros} - - for task in self._scan_tasks: - task.add_done_callback(self._scan_done) - - def _scan_done(self, task: asyncio.Task): - try: - task.result() - except Exception as e: - raise FastCSError( - "Exception raised in scan method of " - f"{self._controller.__class__.__name__}" - ) from e - - def _stop_scan_tasks(self): - for task in self._scan_tasks: - if not task.done(): - try: - task.cancel() - except asyncio.CancelledError: - pass - - -def _link_put_tasks(controller_api: ControllerAPI) -> None: - for name, method in controller_api.put_methods.items(): - name = name.removeprefix("put_") - - attribute = controller_api.attributes[name] - match attribute: - case AttrW(): - attribute.add_process_callback(method.fn) - case _: - raise FastCSError( - f"Attribute type {type(attribute)} does not " - f"support put operations for {name}" - ) - - -def _get_scan_and_initial_coros( - root_controller_api: ControllerAPI, -) -> tuple[list[Callable], list[Callable]]: - scan_dict: dict[float, list[Callable]] = defaultdict(list) - initial_coros: list[Callable] = [] - - for controller_api in root_controller_api.walk_api(): - _add_scan_method_tasks(scan_dict, controller_api) - _add_attribute_updater_tasks(scan_dict, initial_coros, controller_api) - - scan_coros = _get_periodic_scan_coros(scan_dict) - return scan_coros, initial_coros - - -def _add_scan_method_tasks( - scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI -): - for method in controller_api.scan_methods.values(): - scan_dict[method.period].append(method.fn) - - -def _add_attribute_updater_tasks( - scan_dict: dict[float, list[Callable]], - initial_coros: list[Callable], - controller_api: ControllerAPI, -): - for attribute in controller_api.attributes.values(): - match attribute: - case ( - AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute - ): - callback = _create_updater_callback(attribute) - if update_period is ONCE: - initial_coros.append(callback) - elif update_period is not None: - scan_dict[update_period].append(callback) - - -def _create_updater_callback(attribute: AttrR[T]): - async def callback(): - try: - await attribute.update() - except Exception as e: - print(f"Update loop in {attribute} stopped:\n{e.__class__.__name__}: {e}") - raise - - return callback - - -def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: - periodic_scan_coros: list[Callable] = [] - for period, methods in scan_dict.items(): - periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) - - return periodic_scan_coros - - -def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: - async def _sleep(): - await asyncio.sleep(period) - - methods.append(_sleep) # Create periodic behavior - - async def scan_coro() -> None: - while True: - await asyncio.gather(*[method() for method in methods]) - - return scan_coro - - -def build_controller_api(controller: Controller) -> ControllerAPI: - """Build a `ControllerAPI` for a `BaseController` and its sub controllers""" - return _build_controller_api(controller, []) - - -def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: - """Build a `ControllerAPI` for a `BaseController` and its sub controllers""" - scan_methods: dict[str, Scan] = {} - put_methods: dict[str, Put] = {} - command_methods: dict[str, Command] = {} - for attr_name in dir(controller): - attr = getattr(controller, attr_name) - match attr: - case Put(enabled=True): - put_methods[attr_name] = attr - case Scan(enabled=True): - scan_methods[attr_name] = attr - case Command(enabled=True): - command_methods[attr_name] = attr - case _: - pass - - return ControllerAPI( - path=path, - attributes=controller.attributes, - command_methods=command_methods, - put_methods=put_methods, - scan_methods=scan_methods, - sub_apis={ - name: _build_controller_api(sub_controller, path + [name]) - for name, sub_controller in controller.get_sub_controllers().items() - }, - description=controller.description, - ) diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index d35092519..88cc9cb8d 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -2,7 +2,8 @@ import inspect import json import signal -from collections.abc import Coroutine, Sequence +from collections import defaultdict +from collections.abc import Callable, Coroutine, Sequence from functools import partial from pathlib import Path from typing import Annotated, Any, Optional, TypeAlias, get_type_hints @@ -13,16 +14,21 @@ from ruamel.yaml import YAML from fastcs import __version__ +from fastcs.attribute_io_ref import AttributeIORef from fastcs.transport.epics.ca.transport import EpicsCATransport from fastcs.transport.epics.pva.transport import EpicsPVATransport from fastcs.transport.graphql.transport import GraphQLTransport from fastcs.transport.rest.transport import RestTransport from fastcs.transport.tango.transport import TangoTransport -from .backend import Backend -from .controller import Controller -from .exceptions import LaunchError +from .attributes import ONCE, AttrR, AttrW +from .controller import BaseController, Controller +from .controller_api import ControllerAPI +from .cs_methods import Command, Put, Scan +from .datatypes import T +from .exceptions import FastCSError, LaunchError from .transport import Transport +from .util import validate_hinted_attributes # Define a type alias for transport options TransportList: TypeAlias = list[ @@ -35,17 +41,31 @@ class FastCS: - """For launching a controller with given transport(s).""" - - def __init__(self, controller: Controller, transports: Sequence[Transport]): - self._loop = asyncio.get_event_loop() + """For launching a controller with given transport(s) and keeping + track of tasks during serving.""" + + def __init__( + self, + controller: Controller, + transports: Sequence[Transport], + loop: asyncio.AbstractEventLoop | None = None, + ): + self._loop = loop or asyncio.get_event_loop() self._controller = controller - self._backend = Backend(controller, self._loop) + + self._initial_coros = [controller.connect] + self._scan_tasks: set[asyncio.Task] = set() + + # these initialise the controller & build its APIs + self._loop.run_until_complete(controller.initialise()) + self._loop.run_until_complete(controller.attribute_initialise()) + validate_hinted_attributes(controller) + self.controller_api = build_controller_api(controller) + self._link_process_tasks() + self._transports = transports for transport in self._transports: - transport.initialise( - controller_api=self._backend.controller_api, loop=self._loop - ) + transport.initialise(controller_api=self.controller_api, loop=self._loop) def create_docs(self) -> None: for transport in self._transports: @@ -62,11 +82,54 @@ def run(self): self._loop.add_signal_handler(signal.SIGTERM, serve.cancel) self._loop.run_until_complete(serve) + def _link_process_tasks(self): + for controller_api in self.controller_api.walk_api(): + _link_put_tasks(controller_api) + + def __del__(self): + self._stop_scan_tasks() + + async def serve_routines(self): + scans, initials = _get_scan_and_initial_coros(self.controller_api) + self._initial_coros += initials + await self._run_initial_coros() + await self._start_scan_tasks(scans) + + async def _run_initial_coros(self): + for coro in self._initial_coros: + await coro() + + async def _start_scan_tasks( + self, coros: list[Callable[[], Coroutine[None, None, None]]] + ): + self._scan_tasks = {self._loop.create_task(coro()) for coro in coros} + + for task in self._scan_tasks: + task.add_done_callback(self._scan_done) + + def _scan_done(self, task: asyncio.Task): + try: + task.result() + except Exception as e: + raise FastCSError( + "Exception raised in scan method of " + f"{self._controller.__class__.__name__}" + ) from e + + def _stop_scan_tasks(self): + for task in self._scan_tasks: + if not task.done(): + try: + task.cancel() + except asyncio.CancelledError: + pass + async def serve(self) -> None: - coros = [self._backend.serve()] + coros = [self.serve_routines()] + context = { "controller": self._controller, - "controller_api": self._backend.controller_api, + "controller_api": self.controller_api, "transports": [ transport.__class__.__name__ for transport in self._transports ], @@ -118,6 +181,125 @@ async def interactive_shell( await stop_event.wait() +def _link_put_tasks(controller_api: ControllerAPI) -> None: + for name, method in controller_api.put_methods.items(): + name = name.removeprefix("put_") + + attribute = controller_api.attributes[name] + match attribute: + case AttrW(): + attribute.add_process_callback(method.fn) + case _: + raise FastCSError( + f"Attribute type {type(attribute)} does not" + f"support put operations for {name}" + ) + + +def _get_scan_and_initial_coros( + root_controller_api: ControllerAPI, +) -> tuple[list[Callable], list[Callable]]: + scan_dict: dict[float, list[Callable]] = defaultdict(list) + initial_coros: list[Callable] = [] + + for controller_api in root_controller_api.walk_api(): + _add_scan_method_tasks(scan_dict, controller_api) + _add_attribute_updater_tasks(scan_dict, initial_coros, controller_api) + + scan_coros = _get_periodic_scan_coros(scan_dict) + return scan_coros, initial_coros + + +def _add_scan_method_tasks( + scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI +): + for method in controller_api.scan_methods.values(): + scan_dict[method.period].append(method.fn) + + +def _add_attribute_updater_tasks( + scan_dict: dict[float, list[Callable]], + initial_coros: list[Callable], + controller_api: ControllerAPI, +): + for attribute in controller_api.attributes.values(): + match attribute: + case ( + AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute + ): + callback = _create_updater_callback(attribute) + if update_period is ONCE: + initial_coros.append(callback) + elif update_period is not None: + scan_dict[update_period].append(callback) + + +def _create_updater_callback(attribute: AttrR[T]): + async def callback(): + try: + await attribute.update() + except Exception as e: + print(f"Update loop in {attribute} stopped:\n{e.__class__.__name__}: {e}") + raise + + return callback + + +def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: + periodic_scan_coros: list[Callable] = [] + for period, methods in scan_dict.items(): + periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) + + return periodic_scan_coros + + +def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: + async def _sleep(): + await asyncio.sleep(period) + + methods.append(_sleep) # Create periodic behavior + + async def scan_coro() -> None: + while True: + await asyncio.gather(*[method() for method in methods]) + + return scan_coro + + +def build_controller_api(controller: Controller) -> ControllerAPI: + return _build_controller_api(controller, []) + + +def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: + scan_methods: dict[str, Scan] = {} + put_methods: dict[str, Put] = {} + command_methods: dict[str, Command] = {} + for attr_name in dir(controller): + attr = getattr(controller, attr_name) + match attr: + case Put(enabled=True): + put_methods[attr_name] = attr + case Scan(enabled=True): + scan_methods[attr_name] = attr + case Command(enabled=True): + command_methods[attr_name] = attr + case _: + pass + + return ControllerAPI( + path=path, + attributes=controller.attributes, + scan_methods=scan_methods, + put_methods=put_methods, + command_methods=command_methods, + sub_apis={ + name: _build_controller_api(sub_controller, path + [name]) + for name, sub_controller in controller.get_sub_controllers().items() + }, + description=controller.description, + ) + + def launch( controller_class: type[Controller], version: str | None = None, @@ -218,6 +400,7 @@ def run( instance = FastCS( controller, instance_options.transport, + loop=asyncio.get_event_loop(), ) instance.create_gui() diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py index 0ad09754f..d42c91559 100644 --- a/tests/assertable_controller.py +++ b/tests/assertable_controller.py @@ -8,10 +8,10 @@ from fastcs.attribute_io import AttributeIO from fastcs.attribute_io_ref import AttributeIORef from fastcs.attributes import AttrR, AttrW -from fastcs.backend import build_controller_api from fastcs.controller import Controller from fastcs.controller_api import ControllerAPI from fastcs.datatypes import Int, T +from fastcs.launch import build_controller_api from fastcs.wrappers import command, scan diff --git a/tests/benchmarking/controller.py b/tests/benchmarking/controller.py index aaad9bd12..5642a38c1 100644 --- a/tests/benchmarking/controller.py +++ b/tests/benchmarking/controller.py @@ -1,3 +1,5 @@ +import asyncio + from fastcs import FastCS from fastcs.attributes import AttrR, AttrW from fastcs.controller import Controller @@ -23,10 +25,7 @@ def run(): ), TangoTransport(dsr=TangoDSROptions(dev_name="MY/BENCHMARK/DEVICE")), ] - instance = FastCS( - MyTestController(), - transport_options, - ) + instance = FastCS(MyTestController(), transport_options, asyncio.get_event_loop()) instance.run() diff --git a/tests/conftest.py b/tests/conftest.py index 09903b4a8..64db42d23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,8 @@ from softioc import builder from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.backend import build_controller_api from fastcs.datatypes import Bool, Float, Int, String +from fastcs.launch import build_controller_api from fastcs.transport.tango.dsr import register_dev from tests.assertable_controller import MyTestAttributeIORef, MyTestController from tests.example_p4p_ioc import run as _run_p4p_ioc diff --git a/tests/test_backend.py b/tests/test_backend.py deleted file mode 100644 index 4413086b3..000000000 --- a/tests/test_backend.py +++ /dev/null @@ -1,174 +0,0 @@ -import asyncio -from dataclasses import dataclass - -from fastcs.attribute_io import AttributeIO -from fastcs.attribute_io_ref import AttributeIORef -from fastcs.attributes import ONCE, AttrR, AttrRW -from fastcs.backend import Backend, build_controller_api -from fastcs.controller import Controller -from fastcs.cs_methods import Command -from fastcs.datatypes import Int -from fastcs.exceptions import FastCSError -from fastcs.wrappers import command, scan - - -def test_backend(controller): - loop = asyncio.get_event_loop() - backend = Backend(controller, loop) - - # Controller should be initialised by Backend and not connected - assert controller.initialised - assert not controller.connected - - # Controller Attributes with a Sender should have a _process_callback created - assert controller.read_write_int.has_process_callback() - - async def test_wrapper(): - loop.create_task(backend.serve()) - await asyncio.sleep(0) # Yield to task - - # Controller should have been connected by Backend - assert controller.connected - - # Scan tasks should be running - for _ in range(3): - count = controller.count - await asyncio.sleep(0.01) - assert controller.count > count - backend._stop_scan_tasks() - - loop.run_until_complete(test_wrapper()) - - -def test_controller_api(): - class MyTestController(Controller): - attr1: AttrRW[int] = AttrRW(Int()) - - def __init__(self): - super().__init__(description="Controller for testing") - - self.attributes["attr2"] = AttrRW(Int()) - - @command() - async def do_nothing(self): - pass - - @scan(1.0) - async def scan_nothing(self): - pass - - controller = MyTestController() - api = build_controller_api(controller) - - assert api.description == controller.description - assert list(api.attributes) == ["attr1", "attr2"] - assert list(api.command_methods) == ["do_nothing"] - assert list(api.scan_methods) == ["scan_nothing"] - - -def test_controller_api_methods(): - class MyTestController(Controller): - def __init__(self): - super().__init__() - - async def initialise(self): - async def do_nothing_dynamic() -> None: - pass - - self.do_nothing_dynamic = Command(do_nothing_dynamic) - - @command() - async def do_nothing_static(self): - pass - - controller = MyTestController() - loop = asyncio.get_event_loop() - backend = Backend(controller, loop) - - async def test_wrapper(): - await controller.do_nothing_static() - await controller.do_nothing_dynamic() - - await backend.controller_api.command_methods["do_nothing_static"]() - await backend.controller_api.command_methods["do_nothing_dynamic"]() - - loop.run_until_complete(test_wrapper()) - - -def test_update_periods(): - @dataclass - class AttributeIORefTimesCalled(AttributeIORef): - update_period: float | None = None - _times_called = 0 - - class AttributeIOTimesCalled(AttributeIO[int, AttributeIORefTimesCalled]): - async def update(self, attr: AttrR[int, AttributeIORefTimesCalled]): - attr.io_ref._times_called += 1 - await attr.set(attr.io_ref._times_called) - - class MyController(Controller): - update_once = AttrR(Int(), io_ref=AttributeIORefTimesCalled(update_period=ONCE)) - update_quickly = AttrR( - Int(), io_ref=AttributeIORefTimesCalled(update_period=0.1) - ) - update_never = AttrR( - Int(), io_ref=AttributeIORefTimesCalled(update_period=None) - ) - - controller = MyController(ios=[AttributeIOTimesCalled()]) - loop = asyncio.get_event_loop() - - backend = Backend(controller, loop) - - assert controller.update_quickly.get() == 0 - assert controller.update_once.get() == 0 - assert controller.update_never.get() == 0 - - async def test_wrapper(): - loop.create_task(backend.serve()) - await asyncio.sleep(1) - - loop.run_until_complete(test_wrapper()) - assert controller.update_quickly.get() > 1 - assert controller.update_once.get() == 1 - assert controller.update_never.get() == 0 - - assert len(backend._scan_tasks) == 1 - assert len(backend._initial_coros) == 2 - - -def test_scan_raises_exception_via_callback(): - class MyTestController(Controller): - def __init__(self): - super().__init__() - - @scan(0.1) - async def raise_exception(self): - raise ValueError("Scan Exception") - - controller = MyTestController() - loop = asyncio.get_event_loop() - backend = Backend(controller, loop) - - exception_info = {} - # This will intercept the exception raised in _scan_done - loop.set_exception_handler( - lambda _loop, context: exception_info.update( - {"exception": context.get("exception")} - ) - ) - - async def test_scan_wrapper(): - await backend.serve() - # This allows scan time to run - await asyncio.sleep(0.2) - # _scan_done should raise an exception - assert isinstance(exception_info["exception"], FastCSError) - for task in backend._scan_tasks: - internal_exception = task.exception() - assert internal_exception - # The task exception comes from scan method raise_exception - assert isinstance(internal_exception, ValueError) - assert "Scan Exception" == str(internal_exception) - - loop.run_until_complete(test_scan_wrapper()) diff --git a/tests/test_launch.py b/tests/test_launch.py index c64ff72af..811e4ed40 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -1,3 +1,4 @@ +import asyncio import json import os from dataclasses import dataclass @@ -9,11 +10,22 @@ from typer.testing import CliRunner from fastcs import __version__ -from fastcs.attributes import AttrR +from fastcs.attribute_io import AttributeIO +from fastcs.attribute_io_ref import AttributeIORef +from fastcs.attributes import ONCE, AttrR, AttrRW from fastcs.controller import Controller +from fastcs.cs_methods import Command from fastcs.datatypes import Int -from fastcs.exceptions import LaunchError -from fastcs.launch import TransportList, _launch, get_controller_schema, launch +from fastcs.exceptions import FastCSError, LaunchError +from fastcs.launch import ( + FastCS, + TransportList, + _launch, + build_controller_api, + get_controller_schema, + launch, +) +from fastcs.wrappers import command, scan @dataclass @@ -159,3 +171,169 @@ def test_error_if_identical_context_in_transports(mocker: MockerFixture, data): result = runner.invoke(app, ["run", str(data / "config.yaml")]) assert isinstance(result.exception, RuntimeError) assert "Duplicate context keys found" in result.exception.args[0] + + +def test_fastcs(controller): + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + # Controller should be initialised by FastCS and not connected + assert controller.initialised + assert not controller.connected + + # Controller Attributes with a Sender should have a _process_callback created + assert controller.read_write_int.has_process_callback() + + async def test_wrapper(): + loop.create_task(fastcs.serve_routines()) + await asyncio.sleep(0) # Yield to task + + # Controller should have been connected by 'Backend' Logic + assert controller.connected + + # Scan tasks should be running + for _ in range(3): + count = controller.count + await asyncio.sleep(0.01) + assert controller.count > count + fastcs._stop_scan_tasks() + + loop.run_until_complete(test_wrapper()) + + +def test_controller_api(): + class MyTestController(Controller): + attr1: AttrRW[int] = AttrRW(Int()) + + def __init__(self): + super().__init__(description="Controller for testing") + + self.attributes["attr2"] = AttrRW(Int()) + + @command() + async def do_nothing(self): + pass + + @scan(1.0) + async def scan_nothing(self): + pass + + controller = MyTestController() + api = build_controller_api(controller) + + assert api.description == controller.description + assert list(api.attributes) == ["attr1", "attr2"] + assert list(api.command_methods) == ["do_nothing"] + assert list(api.scan_methods) == ["scan_nothing"] + + +def test_controller_api_methods(): + class MyTestController(Controller): + def __init__(self): + super().__init__() + + async def initialise(self): + async def do_nothing_dynamic() -> None: + pass + + self.do_nothing_dynamic = Command(do_nothing_dynamic) + + @command() + async def do_nothing_static(self): + pass + + controller = MyTestController() + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + async def test_wrapper(): + await controller.do_nothing_static() + await controller.do_nothing_dynamic() + + await fastcs.controller_api.command_methods["do_nothing_static"]() + await fastcs.controller_api.command_methods["do_nothing_dynamic"]() + + loop.run_until_complete(test_wrapper()) + + +def test_update_periods(): + @dataclass + class AttributeIORefTimesCalled(AttributeIORef): + update_period: float | None = None + _times_called = 0 + + class AttributeIOTimesCalled(AttributeIO[int, AttributeIORefTimesCalled]): + async def update(self, attr: AttrR[int, AttributeIORefTimesCalled]): + attr.io_ref._times_called += 1 + await attr.set(attr.io_ref._times_called) + + class MyController(Controller): + update_once = AttrR(Int(), io_ref=AttributeIORefTimesCalled(update_period=ONCE)) + update_quickly = AttrR( + Int(), io_ref=AttributeIORefTimesCalled(update_period=0.1) + ) + update_never = AttrR( + Int(), io_ref=AttributeIORefTimesCalled(update_period=None) + ) + + controller = MyController(ios=[AttributeIOTimesCalled()]) + loop = asyncio.get_event_loop() + transport_options = [] + + fastcs = FastCS(controller, transport_options, loop) + + assert controller.update_quickly.get() == 0 + assert controller.update_once.get() == 0 + assert controller.update_never.get() == 0 + + async def test_wrapper(): + loop.create_task(fastcs.serve_routines()) + await asyncio.sleep(1) + + loop.run_until_complete(test_wrapper()) + assert controller.update_quickly.get() > 1 + assert controller.update_once.get() == 1 + assert controller.update_never.get() == 0 + + assert len(fastcs._scan_tasks) == 1 + assert len(fastcs._initial_coros) == 2 + + +def test_scan_raises_exception_via_callback(): + class MyTestController(Controller): + def __init__(self): + super().__init__() + + @scan(0.1) + async def raise_exception(self): + raise ValueError("Scan Exception") + + controller = MyTestController() + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + exception_info = {} + # This will intercept the exception raised in _scan_done + loop.set_exception_handler( + lambda _loop, context: exception_info.update( + {"exception": context.get("exception")} + ) + ) + + async def test_scan_wrapper(): + await fastcs.serve_routines() + # This allows scan time to run + await asyncio.sleep(0.2) + # _scan_done should raise an exception + assert isinstance(exception_info["exception"], FastCSError) + for task in fastcs._scan_tasks: + internal_exception = task.exception() + assert internal_exception + # The task exception comes from scan method raise_exception + assert isinstance(internal_exception, ValueError) + assert "Scan Exception" == str(internal_exception) + + loop.run_until_complete(test_scan_wrapper()) diff --git a/tests/test_util.py b/tests/test_util.py index fd44fee7f..e4ea0af63 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -7,9 +7,9 @@ from pydantic import ValidationError from fastcs.attributes import AttrR, AttrRW -from fastcs.backend import Backend from fastcs.controller import Controller from fastcs.datatypes import Bool, Enum, Float, Int, String +from fastcs.launch import FastCS from fastcs.util import numpy_to_fastcs_datatype, snake_to_pascal @@ -66,6 +66,7 @@ def test_numpy_to_fastcs_datatype(numpy_type, fastcs_datatype): def test_hinted_attributes_verified(): loop = asyncio.get_event_loop() + transport_options = [] class ControllerWithWrongType(Controller): hinted_wrong_type: AttrR[int] @@ -75,7 +76,7 @@ async def initialise(self): self.attributes["hinted_wrong_type"] = self.hinted_wrong_type with pytest.raises(RuntimeError) as excinfo: - Backend(ControllerWithWrongType(), loop) + FastCS(ControllerWithWrongType(), transport_options, loop) assert str(excinfo.value) == ( "Controller 'ControllerWithWrongType' introspection of hinted attribute " "'hinted_wrong_type' does not match defined datatype. " @@ -86,7 +87,7 @@ class ControllerWithMissingAttr(Controller): hinted_int_missing: AttrR[int] with pytest.raises(RuntimeError) as excinfo: - Backend(ControllerWithMissingAttr(), loop) + FastCS(ControllerWithMissingAttr(), transport_options, loop) assert str(excinfo.value) == ( "Controller `ControllerWithMissingAttr` failed to introspect hinted attribute " "`hinted_int_missing` during initialisation" @@ -100,7 +101,7 @@ async def initialise(self): self.attributes["hinted"] = self.hinted with pytest.raises(RuntimeError) as excinfo: - Backend(ControllerAttrWrongAccessMode(), loop) + FastCS(ControllerAttrWrongAccessMode(), transport_options, loop) assert str(excinfo.value) == ( "Controller 'ControllerAttrWrongAccessMode' introspection of hinted attribute " "'hinted' does not match defined access mode. Expected 'AttrR', got 'AttrRW'." @@ -118,7 +119,7 @@ class ControllerWrongEnumClass(Controller): hinted_enum: AttrRW[MyEnum] = AttrRW(Enum(MyEnum2)) with pytest.raises(RuntimeError) as excinfo: - Backend(ControllerWrongEnumClass(), loop) + FastCS(ControllerWrongEnumClass(), transport_options, loop) assert str(excinfo.value) == ( "Controller 'ControllerWrongEnumClass' introspection of hinted attribute " "'hinted_enum' does not match defined datatype. "