diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index 1bcb56382..fce63326a 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from enum import Enum from typing import Any, Generic, Protocol, runtime_checkable @@ -26,7 +27,8 @@ async def put(self, controller: Any, attr: AttrW, value: Any) -> None: class Updater(Protocol): """Protocol for updating the cached readback value of an ``Attribute``.""" - update_period: float + # If update period is None then the attribute will not be updated as a task. + update_period: float | None = None async def update(self, controller: Any, attr: AttrR) -> None: pass @@ -52,6 +54,7 @@ def __init__( group: str | None = None, handler: Any = None, allowed_values: list[T] | None = None, + description: str | None = None, ) -> None: assert ( datatype.dtype in ATTRIBUTE_TYPES @@ -61,6 +64,11 @@ def __init__( self._group = group self.enabled = True self._allowed_values: list[T] | None = allowed_values + self.description = description + + # A callback to use when setting the datatype to a different value, for example + # changing the units on an int. This should be implemented in the backend. + self._update_datatype_callback: Callable[[DataType[T]], None] | None = None @property def datatype(self) -> DataType[T]: @@ -82,6 +90,18 @@ def group(self) -> str | None: def allowed_values(self) -> list[T] | None: return self._allowed_values + def set_update_datatype_callback( + self, callback: Callable[[DataType[T]], None] | None + ) -> None: + self._update_datatype_callback = callback + + def update_datatype(self, datatype: DataType[T]) -> None: + if not isinstance(self._datatype, type(datatype)): + raise ValueError(f"Attribute datatype must be of type {type(datatype)}") + self._datatype = datatype + if self._update_datatype_callback is not None: + self._update_datatype_callback(datatype) + class AttrR(Attribute[T]): """A read-only ``Attribute``.""" @@ -92,7 +112,9 @@ def __init__( access_mode=AttrMode.READ, group: str | None = None, handler: Updater | None = None, + initial_value: T | None = None, allowed_values: list[T] | None = None, + description: str | None = None, ) -> None: super().__init__( datatype, # type: ignore @@ -100,8 +122,11 @@ def __init__( group, handler, allowed_values=allowed_values, # type: ignore + description=description, + ) + self._value: T = ( + datatype.initial_value if initial_value is None else initial_value ) - self._value: T = datatype.dtype() self._update_callback: AttrCallback[T] | None = None self._updater = handler @@ -109,7 +134,7 @@ def get(self) -> T: return self._value async def set(self, value: T) -> None: - self._value = self._datatype.dtype(value) + self._value = self._datatype.cast(value) if self._update_callback is not None: await self._update_callback(self._value) @@ -132,6 +157,7 @@ def __init__( group: str | None = None, handler: Sender | None = None, allowed_values: list[T] | None = None, + description: str | None = None, ) -> None: super().__init__( datatype, # type: ignore @@ -139,6 +165,7 @@ def __init__( group, handler, allowed_values=allowed_values, # type: ignore + description=description, ) self._process_callback: AttrCallback[T] | None = None self._write_display_callback: AttrCallback[T] | None = None @@ -150,11 +177,11 @@ async def process(self, value: T) -> None: async def process_without_display_update(self, value: T) -> None: if self._process_callback is not None: - await self._process_callback(self._datatype.dtype(value)) + await self._process_callback(self._datatype.cast(value)) async def update_display_without_process(self, value: T) -> None: if self._write_display_callback is not None: - await self._write_display_callback(self._datatype.dtype(value)) + await self._write_display_callback(self._datatype.cast(value)) def set_process_callback(self, callback: AttrCallback[T] | None) -> None: self._process_callback = callback @@ -170,7 +197,7 @@ def sender(self) -> Sender | None: return self._sender -class AttrRW(AttrW[T], AttrR[T]): +class AttrRW(AttrR[T], AttrW[T]): """A read-write ``Attribute``.""" def __init__( @@ -179,14 +206,18 @@ def __init__( access_mode=AttrMode.READ_WRITE, group: str | None = None, handler: Handler | None = None, + initial_value: T | None = None, allowed_values: list[T] | None = None, + description: str | None = None, ) -> None: super().__init__( datatype, # type: ignore access_mode, - group, - handler, - allowed_values, # type: ignore + group=group, + handler=handler, + initial_value=initial_value, + allowed_values=allowed_values, # type: ignore + description=description, ) async def process(self, value: T) -> None: diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index 80c8eab67..ec8a61026 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -1,7 +1,6 @@ import asyncio from collections import defaultdict from collections.abc import Callable -from concurrent.futures import Future from types import MethodType from softioc.asyncio_dispatcher import AsyncioDispatcher @@ -21,7 +20,7 @@ def __init__( self._controller = controller self._initial_tasks = [controller.connect] - self._scan_tasks: list[Future] = [] + self._scan_tasks: list[asyncio.Task] = [] asyncio.run_coroutine_threadsafe( self._controller.initialise(), self._loop @@ -41,10 +40,12 @@ def _link_process_tasks(self): _link_single_controller_put_tasks(single_mapping) _link_attribute_sender_class(single_mapping) + def __del__(self): + self.stop_scan_tasks() + def run(self): self._run_initial_tasks() - self._start_scan_tasks() - + self.start_scan_tasks() self._run() def _run_initial_tasks(self): @@ -52,11 +53,18 @@ def _run_initial_tasks(self): future = asyncio.run_coroutine_threadsafe(task(), self._loop) future.result() - def _start_scan_tasks(self): - scan_tasks = _get_scan_tasks(self._mapping) + def start_scan_tasks(self): + self._scan_tasks = [ + self._loop.create_task(coro()) for coro in _get_scan_coros(self._mapping) + ] - for task in scan_tasks: - asyncio.run_coroutine_threadsafe(task(), self._loop) + def stop_scan_tasks(self): + for task in self._scan_tasks: + if not task.done(): + try: + task.cancel() + except asyncio.CancelledError: + pass def _run(self): raise NotImplementedError("Specific Backend must implement _run") @@ -98,15 +106,15 @@ async def callback(value): return callback -def _get_scan_tasks(mapping: Mapping) -> list[Callable]: +def _get_scan_coros(mapping: Mapping) -> list[Callable]: scan_dict: dict[float, list[Callable]] = defaultdict(list) for single_mapping in mapping.get_controller_mappings(): _add_scan_method_tasks(scan_dict, single_mapping) _add_attribute_updater_tasks(scan_dict, single_mapping) - scan_tasks = _get_periodic_scan_tasks(scan_dict) - return scan_tasks + scan_coros = _get_periodic_scan_coros(scan_dict) + return scan_coros def _add_scan_method_tasks( @@ -124,6 +132,8 @@ def _add_attribute_updater_tasks( for attribute in single_mapping.attributes.values(): match attribute: case AttrR(updater=Updater(update_period=update_period)) as attribute: + if update_period is None: + continue callback = _create_updater_callback( attribute, single_mapping.controller ) @@ -144,18 +154,18 @@ async def callback(): return callback -def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callable]]) -> list[Callable]: - periodic_scan_tasks: list[Callable] = [] +def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: + periodic_scan_coros: list[Callable] = [] for period, methods in scan_dict.items(): - periodic_scan_tasks.append(_create_periodic_scan_task(period, methods)) + periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) - return periodic_scan_tasks + return periodic_scan_coros -def _create_periodic_scan_task(period, methods: list[Callable]) -> Callable: - async def scan_task() -> None: +def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: + async def scan_coro() -> None: while True: await asyncio.gather(*[method() for method in methods]) await asyncio.sleep(period) - return scan_task + return scan_coro diff --git a/src/fastcs/backends/epics/backend.py b/src/fastcs/backends/epics/backend.py index 203731838..b34eff8fc 100644 --- a/src/fastcs/backends/epics/backend.py +++ b/src/fastcs/backends/epics/backend.py @@ -7,17 +7,26 @@ class EpicsBackend(Backend): - def __init__(self, controller: Controller, pv_prefix: str = "MY-DEVICE-PREFIX"): + def __init__( + self, + controller: Controller, + pv_prefix: str = "MY-DEVICE-PREFIX", + ioc_options: EpicsIOCOptions | None = None, + ): super().__init__(controller) self._pv_prefix = pv_prefix - self._ioc = EpicsIOC(pv_prefix, self._mapping) + self.ioc_options = ioc_options or EpicsIOCOptions() + self._ioc = EpicsIOC(pv_prefix, self._mapping, options=ioc_options) - def create_docs(self, options: EpicsDocsOptions | None = None) -> None: - EpicsDocs(self._mapping).create_docs(options) + def create_docs(self, docs_options: EpicsDocsOptions | None = None) -> None: + EpicsDocs(self._mapping).create_docs(docs_options) - def create_gui(self, options: EpicsGUIOptions | None = None) -> None: - EpicsGUI(self._mapping, self._pv_prefix).create_gui(options) + def create_gui(self, gui_options: EpicsGUIOptions | None = None) -> None: + assert self.ioc_options.name_options is not None + EpicsGUI( + self._mapping, self._pv_prefix, self.ioc_options.name_options + ).create_gui(gui_options) - def _run(self, options: EpicsIOCOptions | None = None): - self._ioc.run(self._dispatcher, self._context, options) + def _run(self): + self._ioc.run(self._dispatcher, self._context) diff --git a/src/fastcs/backends/epics/gui.py b/src/fastcs/backends/epics/gui.py index b9c48751a..4a040ec67 100644 --- a/src/fastcs/backends/epics/gui.py +++ b/src/fastcs/backends/epics/gui.py @@ -27,6 +27,10 @@ from pydantic import ValidationError from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW +from fastcs.backends.epics.util import ( + EpicsNameOptions, + _convert_attribute_name_to_pv_name, +) from fastcs.cs_methods import Command from fastcs.datatypes import Bool, Float, Int, String from fastcs.exceptions import FastCSException @@ -39,7 +43,7 @@ class EpicsGUIFormat(Enum): edl = ".edl" -@dataclass +@dataclass(frozen=True) class EpicsGUIOptions: output_path: Path = Path.cwd() / "output.bob" file_format: EpicsGUIFormat = EpicsGUIFormat.bob @@ -47,13 +51,37 @@ class EpicsGUIOptions: class EpicsGUI: - def __init__(self, mapping: Mapping, pv_prefix: str) -> None: + def __init__( + self, + mapping: Mapping, + pv_prefix: str, + epics_name_options: EpicsNameOptions | None = None, + ) -> None: self._mapping = mapping self._pv_prefix = pv_prefix + self.epics_name_options = epics_name_options or EpicsNameOptions() def _get_pv(self, attr_path: list[str], name: str): - attr_prefix = ":".join([self._pv_prefix] + attr_path) - return f"{attr_prefix}:{name.title().replace('_', '')}" + return self.epics_name_options.pv_separator.join( + [ + self._pv_prefix, + ] + + [ + _convert_attribute_name_to_pv_name( + attr_name, + self.epics_name_options.pv_naming_convention, + is_attribute=False, + ) + for attr_name in attr_path + ] + + [ + _convert_attribute_name_to_pv_name( + name, + self.epics_name_options.pv_naming_convention, + is_attribute=True, + ), + ], + ) @staticmethod def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: diff --git a/src/fastcs/backends/epics/ioc.py b/src/fastcs/backends/epics/ioc.py index 7aa0e8924..a3bedd13b 100644 --- a/src/fastcs/backends/epics/ioc.py +++ b/src/fastcs/backends/epics/ioc.py @@ -1,53 +1,305 @@ from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import asdict, dataclass from types import MethodType from typing import Any, Literal -from softioc import builder, softioc +import numpy as np +from softioc import builder, fields, softioc from softioc.asyncio_dispatcher import AsyncioDispatcher +from softioc.imports import db_put_field from softioc.pythonSoftIoc import RecordWrapper from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.backends.epics.util import ( MBB_STATE_FIELDS, + EpicsNameOptions, + _convert_attribute_name_to_pv_name, attr_is_enum, enum_index_to_value, enum_value_to_index, ) from fastcs.controller import BaseController -from fastcs.datatypes import Bool, Float, Int, String, T +from fastcs.datatypes import Bool, DataType, Float, Int, String, T, WaveForm from fastcs.exceptions import FastCSException from fastcs.mapping import Mapping EPICS_MAX_NAME_LENGTH = 60 -@dataclass +@dataclass(frozen=True) class EpicsIOCOptions: terminal: bool = True + name_options: EpicsNameOptions = EpicsNameOptions() + + +DATATYPE_NAME_TO_RECORD_FIELD = { + "prec": "PREC", + "units": "EGU", + "min": "DRVL", + "max": "DRVH", + "min_alarm": "LOPR", + "max_alarm": "HOPR", + "znam": "ZNAM", + "onam": "ONAM", + "length": "length", +} + + +def datatype_to_epics_fields(datatype: DataType) -> dict[str, Any]: + return { + DATATYPE_NAME_TO_RECORD_FIELD[field]: value + for field, value in asdict(datatype).items() + } + + +def reload_attribute_fields(pv: str, DataType: DataType): + """If the ioc side changes a field on the attribute + e.g ``units`` then this method will update it on the attribute""" + + for name, value in datatype_to_epics_fields(DataType): + # TODO: can we just make every dtype a string and have the ioc convert? + array = np.require(value, dtype=np.dtype("S40")) + db_put_field(f"{pv}.{name}", fields.DBR_STRING, array, 1) class EpicsIOC: - def __init__(self, pv_prefix: str, mapping: Mapping): - _add_pvi_info(f"{pv_prefix}:PVI") - _add_sub_controller_pvi_info(pv_prefix, mapping.controller) + def __init__( + self, pv_prefix: str, mapping: Mapping, options: EpicsIOCOptions | None = None + ): + self._options = options or EpicsIOCOptions() + self._name_options = self._options.name_options + + _add_pvi_info(f"{pv_prefix}{self._name_options.pv_separator}PVI") + self._add_sub_controller_pvi_info(pv_prefix, mapping.controller) - _create_and_link_attribute_pvs(pv_prefix, mapping) - _create_and_link_command_pvs(pv_prefix, mapping) + self._create_and_link_attribute_pvs(pv_prefix, mapping) + self._create_and_link_command_pvs(pv_prefix, mapping) def run( self, dispatcher: AsyncioDispatcher, context: dict[str, Any], - options: EpicsIOCOptions | None = None, ) -> None: - if options is None: - options = EpicsIOCOptions() - builder.LoadDatabase() softioc.iocInit(dispatcher) - softioc.interactive_ioc(context) + if self._options.terminal: + softioc.interactive_ioc(context) + + def _add_sub_controller_pvi_info(self, pv_prefix: str, parent: BaseController): + """Add PVI references from controller to its sub controllers, recursively. + + Args: + pv_prefix: PV Prefix of IOC + parent: Controller to add PVI refs for + + """ + parent_pvi = self._name_options.pv_separator.join( + [pv_prefix] + parent.path + ["PVI"] + ) + + for child in parent.get_sub_controllers().values(): + child_pvi = self._name_options.pv_separator.join( + [pv_prefix] + + [ + _convert_attribute_name_to_pv_name( + path, + self._name_options.pv_naming_convention, + is_attribute=False, + ) + for path in child.path + ] + + ["PVI"] + ) + child_name = child.path[-1].lower() + + _add_pvi_info(child_pvi, parent_pvi, child_name) + + self._add_sub_controller_pvi_info(pv_prefix, child) + + def _create_and_link_attribute_pvs(self, pv_prefix: str, mapping: Mapping) -> None: + for single_mapping in mapping.get_controller_mappings(): + formatted_path = [ + _convert_attribute_name_to_pv_name( + p, self._name_options.pv_naming_convention, is_attribute=False + ) + for p in single_mapping.controller.path + ] + for attr_name, attribute in single_mapping.attributes.items(): + pv_name = _convert_attribute_name_to_pv_name( + attr_name, + self._name_options.pv_naming_convention, + is_attribute=True, + ) + _pv_prefix = self._name_options.pv_separator.join( + [pv_prefix] + formatted_path + ) + full_pv_name_length = len( + f"{_pv_prefix}{self._name_options.pv_separator}{pv_name}" + ) + + if full_pv_name_length > EPICS_MAX_NAME_LENGTH: + attribute.enabled = False + print( + f"Not creating PV for {attr_name} for controller" + f" {single_mapping.controller.path} as full name would exceed" + f" {EPICS_MAX_NAME_LENGTH} characters" + ) + continue + + match attribute: + case AttrRW(): + if full_pv_name_length > (EPICS_MAX_NAME_LENGTH - 4): + print( + f"Not creating PVs for {attr_name} as _RBV PV" + f" name would exceed {EPICS_MAX_NAME_LENGTH}" + " characters" + ) + attribute.enabled = False + else: + self._create_and_link_read_pv( + _pv_prefix, f"{pv_name}_RBV", attr_name, attribute + ) + self._create_and_link_write_pv( + _pv_prefix, pv_name, attr_name, attribute + ) + case AttrR(): + self._create_and_link_read_pv( + _pv_prefix, pv_name, attr_name, attribute + ) + case AttrW(): + self._create_and_link_write_pv( + _pv_prefix, pv_name, attr_name, attribute + ) + + def _create_and_link_read_pv( + self, pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] + ) -> None: + if attr_is_enum(attribute): + + async def async_record_set(value: T): + record.set(enum_value_to_index(attribute, value)) + + else: + + async def async_record_set(value: T): + record.set(value) + + record = _get_input_record( + f"{pv_prefix}{self._name_options.pv_separator}{pv_name}", attribute + ) + self._add_attr_pvi_info(record, pv_prefix, attr_name, "r") + + attribute.set_update_callback(async_record_set) + + def _create_and_link_command_pvs(self, pv_prefix: str, mapping: Mapping) -> None: + for single_mapping in mapping.get_controller_mappings(): + formatted_path = [ + _convert_attribute_name_to_pv_name( + p, self._name_options.pv_naming_convention, is_attribute=False + ) + for p in single_mapping.controller.path + ] + for attr_name, method in single_mapping.command_methods.items(): + pv_name = _convert_attribute_name_to_pv_name( + attr_name, + self._name_options.pv_naming_convention, + is_attribute=True, + ) + _pv_prefix = self._name_options.pv_separator.join( + [pv_prefix] + formatted_path + ) + if ( + len(f"{_pv_prefix}{self._name_options.pv_separator}{pv_name}") + > EPICS_MAX_NAME_LENGTH + ): + print( + f"Not creating PV for {attr_name} as full name would exceed" + f" {EPICS_MAX_NAME_LENGTH} characters" + ) + method.enabled = False + else: + self._create_and_link_command_pv( + _pv_prefix, + pv_name, + attr_name, + MethodType(method.fn, single_mapping.controller), + ) + + def _create_and_link_write_pv( + self, pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] + ) -> None: + if attr_is_enum(attribute): + + async def on_update(value): + await attribute.process_without_display_update( + enum_index_to_value(attribute, value) + ) + + async def async_write_display(value: T): + record.set(enum_value_to_index(attribute, value), process=False) + + else: + + async def on_update(value): + await attribute.process_without_display_update(value) + + async def async_write_display(value: T): + record.set(value, process=False) + + record = _get_output_record( + f"{pv_prefix}{self._name_options.pv_separator}{pv_name}", + attribute, + on_update=on_update, + ) + + self._add_attr_pvi_info(record, pv_prefix, attr_name, "w") + + attribute.set_write_display_callback(async_write_display) + + def _create_and_link_command_pv( + self, pv_prefix: str, pv_name: str, attr_name: str, method: Callable + ) -> None: + async def wrapped_method(_: Any): + await method() + + record = builder.aOut( + f"{pv_prefix}{self._name_options.pv_separator}{pv_name}", + initial_value=0, + always_update=True, + on_update=wrapped_method, + ) + + self._add_attr_pvi_info(record, pv_prefix, attr_name, "x") + + def _add_attr_pvi_info( + self, + record: RecordWrapper, + prefix: str, + name: str, + access_mode: Literal["r", "w", "rw", "x"], + ): + """Add an info tag to a record to include it in the PVI for the controller. + + Args: + record: Record to add info tag to + prefix: PV prefix of controller + name: Name of parameter to add to PVI + access_mode: Access mode of parameter + + """ + record.add_info( + "Q:group", + { + f"{prefix}{self._name_options.pv_separator}PVI": { + f"value.{name}.{access_mode}": { + "+channel": "NAME", + "+type": "plain", + "+trigger": f"value.{name}.{access_mode}", + } + } + }, + ) def _add_pvi_info( @@ -95,224 +347,113 @@ def _add_pvi_info( record.add_info("Q:group", q_group) -def _add_sub_controller_pvi_info(pv_prefix: str, parent: BaseController): - """Add PVI references from controller to its sub controllers, recursively. - - Args: - pv_prefix: PV Prefix of IOC - parent: Controller to add PVI refs for - - """ - parent_pvi = ":".join([pv_prefix] + parent.path + ["PVI"]) - - for child in parent.get_sub_controllers().values(): - child_pvi = ":".join([pv_prefix] + child.path + ["PVI"]) - child_name = child.path[-1].lower() - - _add_pvi_info(child_pvi, parent_pvi, child_name) - - _add_sub_controller_pvi_info(pv_prefix, child) - - -def _create_and_link_attribute_pvs(pv_prefix: str, mapping: Mapping) -> None: - for single_mapping in mapping.get_controller_mappings(): - path = single_mapping.controller.path - for attr_name, attribute in single_mapping.attributes.items(): - pv_name = attr_name.title().replace("_", "") - _pv_prefix = ":".join([pv_prefix] + path) - full_pv_name_length = len(f"{_pv_prefix}:{pv_name}") - - if full_pv_name_length > EPICS_MAX_NAME_LENGTH: - attribute.enabled = False - print( - f"Not creating PV for {attr_name} for controller" - f" {single_mapping.controller.path} as full name would exceed" - f" {EPICS_MAX_NAME_LENGTH} characters" - ) - continue - - match attribute: - case AttrRW(): - if full_pv_name_length > (EPICS_MAX_NAME_LENGTH - 4): - print( - f"Not creating PVs for {attr_name} as _RBV PV" - f" name would exceed {EPICS_MAX_NAME_LENGTH}" - " characters" - ) - attribute.enabled = False - else: - _create_and_link_read_pv( - _pv_prefix, f"{pv_name}_RBV", attr_name, attribute - ) - _create_and_link_write_pv( - _pv_prefix, pv_name, attr_name, attribute - ) - case AttrR(): - _create_and_link_read_pv(_pv_prefix, pv_name, attr_name, attribute) - case AttrW(): - _create_and_link_write_pv(_pv_prefix, pv_name, attr_name, attribute) - - -def _create_and_link_read_pv( - pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] -) -> None: - if attr_is_enum(attribute): - - async def async_record_set(value: T): - record.set(enum_value_to_index(attribute, value)) - else: - - async def async_record_set(value: T): - record.set(value) - - record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute) - _add_attr_pvi_info(record, pv_prefix, attr_name, "r") - - attribute.set_update_callback(async_record_set) - - def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: + attribute_fields = {} + if attribute.description is not None: + attribute_fields.update({"DESC": attribute.description}) + if attr_is_enum(attribute): assert attribute.allowed_values is not None and all( isinstance(v, str) for v in attribute.allowed_values ) state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbIn(pv, **state_keys) + return builder.mbbIn(pv, **state_keys, **attribute_fields) + + def datatype_updater(datatype: DataType): + reload_attribute_fields(pv, datatype) + + attribute.set_update_datatype_callback(datatype_updater) match attribute.datatype: - case Bool(znam, onam): - return builder.boolIn(pv, ZNAM=znam, ONAM=onam) + case Bool(): + return builder.boolIn( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) case Int(): - return builder.longIn(pv) - case Float(prec): - return builder.aIn(pv, PREC=prec) + return builder.longIn( + pv, + **datatype_to_epics_fields(attribute.datatype), + **attribute_fields, + ) + case Float(): + return builder.aIn( + pv, + **datatype_to_epics_fields(attribute.datatype), + **attribute_fields, + ) case String(): - return builder.longStringIn(pv) + return builder.longStringIn( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) + case WaveForm(): + return builder.WaveformIn( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) -def _create_and_link_write_pv( - pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] -) -> None: - if attr_is_enum(attribute): - - async def on_update(value): - await attribute.process_without_display_update( - enum_index_to_value(attribute, value) - ) - - async def async_write_display(value: T): - record.set(enum_value_to_index(attribute, value), process=False) - - else: - - async def on_update(value): - await attribute.process_without_display_update(value) - - async def async_write_display(value: T): - record.set(value, process=False) - - record = _get_output_record( - f"{pv_prefix}:{pv_name}", attribute, on_update=on_update - ) - _add_attr_pvi_info(record, pv_prefix, attr_name, "w") - - attribute.set_write_display_callback(async_write_display) - - def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: + attribute_fields = {} + if attribute.description is not None: + attribute_fields.update({"DESC": attribute.description}) if attr_is_enum(attribute): assert attribute.allowed_values is not None and all( - isinstance(v, str) for v in attribute.allowed_values + isinstance(v, str) or isinstance(v, int) for v in attribute.allowed_values ) state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbOut(pv, always_update=True, on_update=on_update, **state_keys) + return builder.mbbOut( + pv, + always_update=True, + on_update=on_update, + **state_keys, + **attribute_fields, + ) + + def datatype_updater(datatype: DataType): + reload_attribute_fields(pv, datatype) + + attribute.set_update_datatype_callback(datatype_updater) match attribute.datatype: - case Bool(znam, onam): + case Bool(): return builder.boolOut( pv, - ZNAM=znam, - ONAM=onam, + **datatype_to_epics_fields(attribute.datatype), always_update=True, on_update=on_update, ) case Int(): - return builder.longOut(pv, always_update=True, on_update=on_update) - case Float(prec): - return builder.aOut(pv, always_update=True, on_update=on_update, PREC=prec) + return builder.longOut( + pv, + always_update=True, + on_update=on_update, + **datatype_to_epics_fields(attribute.datatype), + **attribute_fields, + ) + case Float(): + return builder.aOut( + pv, + always_update=True, + on_update=on_update, + **datatype_to_epics_fields(attribute.datatype), + **attribute_fields, + ) case String(): - return builder.longStringOut(pv, always_update=True, on_update=on_update) + return builder.longStringOut( + pv, + always_update=True, + on_update=on_update, + **datatype_to_epics_fields(attribute.datatype), + **attribute_fields, + ) + case WaveForm(): + return builder.WaveformOut( + pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + ) case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) - - -def _create_and_link_command_pvs(pv_prefix: str, mapping: Mapping) -> None: - for single_mapping in mapping.get_controller_mappings(): - path = single_mapping.controller.path - for attr_name, method in single_mapping.command_methods.items(): - pv_name = attr_name.title().replace("_", "") - _pv_prefix = ":".join([pv_prefix] + path) - if len(f"{_pv_prefix}:{pv_name}") > EPICS_MAX_NAME_LENGTH: - print( - f"Not creating PV for {attr_name} as full name would exceed" - f" {EPICS_MAX_NAME_LENGTH} characters" - ) - method.enabled = False - else: - _create_and_link_command_pv( - _pv_prefix, - pv_name, - attr_name, - MethodType(method.fn, single_mapping.controller), - ) - - -def _create_and_link_command_pv( - pv_prefix: str, pv_name: str, attr_name: str, method: Callable -) -> None: - async def wrapped_method(_: Any): - await method() - - record = builder.aOut( - f"{pv_prefix}:{pv_name}", - initial_value=0, - always_update=True, - on_update=wrapped_method, - ) - - _add_attr_pvi_info(record, pv_prefix, attr_name, "x") - - -def _add_attr_pvi_info( - record: RecordWrapper, - prefix: str, - name: str, - access_mode: Literal["r", "w", "rw", "x"], -): - """Add an info tag to a record to include it in the PVI for the controller. - - Args: - record: Record to add info tag to - prefix: PV prefix of controller - name: Name of parameter to add to PVI - access_mode: Access mode of parameter - - """ - record.add_info( - "Q:group", - { - f"{prefix}:PVI": { - f"value.{name}.{access_mode}": { - "+channel": "NAME", - "+type": "plain", - "+trigger": f"value.{name}.{access_mode}", - } - } - }, - ) diff --git a/src/fastcs/backends/epics/util.py b/src/fastcs/backends/epics/util.py index b1ffa6089..973cef43a 100644 --- a/src/fastcs/backends/epics/util.py +++ b/src/fastcs/backends/epics/util.py @@ -1,3 +1,6 @@ +from dataclasses import dataclass +from enum import Enum + from fastcs.attributes import Attribute from fastcs.datatypes import String, T @@ -25,6 +28,42 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) +class PvNamingConvention(Enum): + NO_CONVERSION = "NO_CONVERSION" + PASCAL = "PASCAL" + CAPITALIZED = "CAPITALIZED" + CAPITALIZED_CONTROLLER_PASCAL_ATTRIBUTE = "CAPITALIZED_CONTROLLER_PASCAL_ATTRIBUTE" + + +DEFAULT_PV_SEPARATOR = ":" + + +@dataclass(frozen=True) +class EpicsNameOptions: + pv_naming_convention: PvNamingConvention = PvNamingConvention.PASCAL + pv_separator: str = DEFAULT_PV_SEPARATOR + + +def _convert_attribute_name_to_pv_name( + attr_name: str, naming_convention: PvNamingConvention, is_attribute: bool = False +) -> str: + if naming_convention == PvNamingConvention.PASCAL: + return attr_name.title().replace("_", "") + elif naming_convention == PvNamingConvention.CAPITALIZED: + return attr_name.upper().replace("_", "-") + elif ( + naming_convention == PvNamingConvention.CAPITALIZED_CONTROLLER_PASCAL_ATTRIBUTE + ): + if is_attribute: + return _convert_attribute_name_to_pv_name( + attr_name, PvNamingConvention.PASCAL, is_attribute + ) + return _convert_attribute_name_to_pv_name( + attr_name, PvNamingConvention.CAPITALIZED + ) + return attr_name + + def attr_is_enum(attribute: Attribute) -> bool: """Check if the `Attribute` has a `String` datatype and has `allowed_values` set. @@ -36,9 +75,9 @@ def attr_is_enum(attribute: Attribute) -> bool: """ match attribute: - case Attribute( - datatype=String(), allowed_values=allowed_values - ) if allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES: + case Attribute(datatype=String(), allowed_values=allowed_values) if ( + allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES + ): return True case _: return False diff --git a/src/fastcs/backends/tango/dsr.py b/src/fastcs/backends/tango/dsr.py index d8689e466..ea5e8365c 100644 --- a/src/fastcs/backends/tango/dsr.py +++ b/src/fastcs/backends/tango/dsr.py @@ -35,7 +35,7 @@ async def fget(tango_device: Device): def _tango_polling_period(attribute: AttrR) -> int: - if attribute.updater is not None: + if attribute.updater is not None and attribute.updater.update_period is not None: # Convert to integer milliseconds return int(attribute.updater.update_period * 1000) diff --git a/src/fastcs/controller.py b/src/fastcs/controller.py index 95c64d31f..882cbbb8c 100755 --- a/src/fastcs/controller.py +++ b/src/fastcs/controller.py @@ -6,12 +6,23 @@ class BaseController: - def __init__(self, path: list[str] | None = None) -> None: + def __init__( + self, path: list[str] | None = None, search_device_for_attributes: bool = True + ) -> None: self._path: list[str] = path or [] - self.__sub_controller_tree: dict[str, BaseController] = {} + # If this is set to `False`, FastCS will only use attributes defined in + # `additional_attributes`. + self.search_device_for_attributes = search_device_for_attributes + self.__sub_controller_tree: dict[str, BaseController] = {} self._bind_attrs() + @property + def additional_attributes(self) -> dict[str, Attribute]: + """FastCS will look for attributes on the controller, but additional attributes + are provided by this method.""" + return {} + @property def path(self) -> list[str]: """Path prefix of attributes, recursively including parent ``Controller``s.""" @@ -52,8 +63,8 @@ class Controller(BaseController): generating a UI or creating parameters for a control system. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, search_device_for_attributes: bool = True) -> None: + super().__init__(search_device_for_attributes=search_device_for_attributes) async def initialise(self) -> None: pass @@ -69,5 +80,5 @@ class SubController(BaseController): it as part of a larger device. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, search_device_for_attributes: bool = True) -> None: + super().__init__(search_device_for_attributes=search_device_for_attributes) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index ccb361d07..6a264273f 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -1,23 +1,39 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, Literal, TypeVar -T = TypeVar("T", int, float, bool, str) +import numpy as np + +T = TypeVar("T", int, float, bool, str, np.ndarray) # type: ignore ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore AttrCallback = Callable[[T], Awaitable[None]] +@dataclass(frozen=True) # So that we can type hint with dataclass methods class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" @property @abstractmethod - def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars + def dtype( + self, + ) -> type[T]: # Using property due to lack of Generic ClassVars + pass + + @property + @abstractmethod + def initial_value(self) -> T: + """Return an initial value for the datatype.""" + pass + + @abstractmethod + def cast(self, value: Any) -> T: + """Cast a value to the datatype to put to the backend.""" pass @@ -25,21 +41,54 @@ def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars class Int(DataType[int]): """`DataType` mapping to builtin ``int``.""" + units: str | None = None + min: int | None = None + max: int | None = None + min_alarm: int | None = None + max_alarm: int | None = None + @property def dtype(self) -> type[int]: return int + @property + def initial_value(self) -> Literal[0]: + return 0 + + def cast(self, value: Any) -> int: + if self.min is not None and value < self.min: + raise ValueError(f"Value {value} is less than minimum {self.min}") + if self.max is not None and value > self.max: + raise ValueError(f"Value {value} is greater than maximum {self.max}") + return int(value) + @dataclass(frozen=True) class Float(DataType[float]): """`DataType` mapping to builtin ``float``.""" prec: int = 2 + units: str | None = None + min: float | None = None + max: float | None = None + min_alarm: float | None = None + max_alarm: float | None = None @property def dtype(self) -> type[float]: return float + @property + def intial_value(self) -> float: + return 0.0 + + def cast(self, value: Any) -> float: + if self.min is not None and value < self.min: + raise ValueError(f"Value {value} is less than minimum {self.min}") + if self.max is not None and value > self.max: + raise ValueError(f"Value {value} is greater than maximum {self.max}") + return float(value) + @dataclass(frozen=True) class Bool(DataType[bool]): @@ -52,6 +101,13 @@ class Bool(DataType[bool]): def dtype(self) -> type[bool]: return bool + @property + def intial_value(self) -> Literal[False]: + return False + + def cast(self, value: Any) -> bool: + return bool(value) + @dataclass(frozen=True) class String(DataType[str]): @@ -60,3 +116,46 @@ class String(DataType[str]): @property def dtype(self) -> type[str]: return str + + @property + def intial_value(self) -> Literal[""]: + return "" + + def cast(self, value: Any) -> str: + return str(value) + + +DEFAULT_WAVEFORM_LENGTH = 20000 + + +@dataclass(frozen=True) +class WaveForm(DataType[np.ndarray]): + """ + DataType for a waveform, values are of the numpy `datatype` + """ + + numpy_datatype: np.dtype + length: int = DEFAULT_WAVEFORM_LENGTH + + @property + def dtype(self) -> type[np.ndarray]: + return np.ndarray + + @property + def initial_value(self) -> np.ndarray: + return np.ndarray(self.length, dtype=self.numpy_datatype) + + def cast(self, value: Sequence | np.ndarray) -> np.ndarray: + if len(value) > self.length: + raise ValueError( + f"Waveform length {len(value)} is greater than maximum {self.length}." + ) + if isinstance(value, np.ndarray): + if value.dtype != self.numpy_datatype: + raise ValueError( + f"Waveform dtype {value.dtype} does not " + f"match {self.numpy_datatype}." + ) + return value + else: + return np.array(value, dtype=self.numpy_datatype) diff --git a/src/fastcs/mapping.py b/src/fastcs/mapping.py index d15cb481b..17e0d1bbd 100644 --- a/src/fastcs/mapping.py +++ b/src/fastcs/mapping.py @@ -52,7 +52,17 @@ def _get_single_mapping(controller: BaseController) -> SingleMapping: case WrappedMethod(fastcs_method=Command(enabled=True) as command_method): command_methods[attr_name] = command_method case Attribute(enabled=True): - attributes[attr_name] = attr + if controller.search_device_for_attributes: + attributes[attr_name] = attr + + additional_attributes = controller.additional_attributes + if common_attributes := additional_attributes.keys() & attributes.keys(): + raise RuntimeError( + f"Received additional attributes {common_attributes} " + f"already present in the controller {controller}." + ) + + attributes.update(additional_attributes) return SingleMapping( controller, scan_methods, put_methods, command_methods, attributes diff --git a/tests/backends/epics/test_ioc.py b/tests/backends/epics/test_ioc.py index c1a5b3218..4bac2fcd1 100644 --- a/tests/backends/epics/test_ioc.py +++ b/tests/backends/epics/test_ioc.py @@ -7,11 +7,7 @@ from fastcs.backends.epics.ioc import ( EPICS_MAX_NAME_LENGTH, EpicsIOC, - _add_attr_pvi_info, _add_pvi_info, - _add_sub_controller_pvi_info, - _create_and_link_read_pv, - _create_and_link_write_pv, _get_input_record, _get_output_record, ) @@ -27,17 +23,31 @@ ONOFF_STATES = {"ZRST": "disabled", "ONST": "enabled"} +@pytest.fixture +def ioc_without_mapping(mocker: MockerFixture, mapping: Mapping): + mocker.patch("fastcs.backends.epics.ioc.builder") + mocker.patch("fastcs.backends.epics.ioc.EpicsIOC._create_and_link_attribute_pvs") + mocker.patch("fastcs.backends.epics.ioc.EpicsIOC._create_and_link_command_pvs") + + return EpicsIOC(DEVICE, mapping) + + @pytest.mark.asyncio -async def test_create_and_link_read_pv(mocker: MockerFixture): +async def test_create_and_link_read_pv( + mocker: MockerFixture, ioc_without_mapping: EpicsIOC +): get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record") - add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info") attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum") + mocker.patch("fastcs.backends.epics.ioc._add_pvi_info") + add_attr_pvi_info = mocker.patch( + "fastcs.backends.epics.ioc.EpicsIOC._add_attr_pvi_info" + ) record = get_input_record.return_value attribute = mocker.MagicMock() - attr_is_enum.return_value = False - _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) + + ioc_without_mapping._create_and_link_read_pv("PREFIX", "PV", "attr", attribute) get_input_record.assert_called_once_with("PREFIX:PV", attribute) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") @@ -51,9 +61,13 @@ async def test_create_and_link_read_pv(mocker: MockerFixture): @pytest.mark.asyncio -async def test_create_and_link_read_pv_enum(mocker: MockerFixture): +async def test_create_and_link_read_pv_enum( + mocker: MockerFixture, ioc_without_mapping: EpicsIOC +): get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record") - add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info") + add_attr_pvi_info = mocker.patch( + "fastcs.backends.epics.ioc.EpicsIOC._add_attr_pvi_info" + ) attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum") record = get_input_record.return_value enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index") @@ -61,7 +75,7 @@ async def test_create_and_link_read_pv_enum(mocker: MockerFixture): attribute = mocker.MagicMock() attr_is_enum.return_value = True - _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) + ioc_without_mapping._create_and_link_read_pv("PREFIX", "PV", "attr", attribute) get_input_record.assert_called_once_with("PREFIX:PV", attribute) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") @@ -108,9 +122,13 @@ def test_get_input_record_raises(mocker: MockerFixture): @pytest.mark.asyncio -async def test_create_and_link_write_pv(mocker: MockerFixture): +async def test_create_and_link_write_pv( + mocker: MockerFixture, ioc_without_mapping: EpicsIOC +): get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record") - add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info") + add_attr_pvi_info = mocker.patch( + "fastcs.backends.epics.ioc.EpicsIOC._add_attr_pvi_info" + ) attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum") record = get_output_record.return_value @@ -118,7 +136,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): attribute.process_without_display_update = mocker.AsyncMock() attr_is_enum.return_value = False - _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) + ioc_without_mapping._create_and_link_write_pv("PREFIX", "PV", "attr", attribute) get_output_record.assert_called_once_with( "PREFIX:PV", attribute, on_update=mocker.ANY @@ -140,9 +158,13 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): @pytest.mark.asyncio -async def test_create_and_link_write_pv_enum(mocker: MockerFixture): +async def test_create_and_link_write_pv_enum( + mocker: MockerFixture, ioc_without_mapping: EpicsIOC +): get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record") - add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info") + add_attr_pvi_info = mocker.patch( + "fastcs.backends.epics.ioc.EpicsIOC._add_attr_pvi_info" + ) attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum") enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index") enum_index_to_value = mocker.patch("fastcs.backends.epics.ioc.enum_index_to_value") @@ -152,7 +174,7 @@ async def test_create_and_link_write_pv_enum(mocker: MockerFixture): attribute.process_without_display_update = mocker.AsyncMock() attr_is_enum.return_value = True - _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) + ioc_without_mapping._create_and_link_write_pv("PREFIX", "PV", "attr", attribute) get_output_record.assert_called_once_with( "PREFIX:PV", attribute, on_update=mocker.ANY @@ -211,26 +233,46 @@ def test_get_output_record_raises(mocker: MockerFixture): _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) +DEFAULT_SCALAR_FIELD_ARGS = { + "EGU": None, + "DRVL": None, + "DRVH": None, + "LOPR": None, + "HOPR": None, +} + + def test_ioc(mocker: MockerFixture, mapping: Mapping): builder = mocker.patch("fastcs.backends.epics.ioc.builder") add_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_pvi_info") add_sub_controller_pvi_info = mocker.patch( - "fastcs.backends.epics.ioc._add_sub_controller_pvi_info" + "fastcs.backends.epics.ioc.EpicsIOC._add_sub_controller_pvi_info" ) EpicsIOC(DEVICE, mapping) # Check records are created builder.boolIn.assert_called_once_with(f"{DEVICE}:ReadBool", ZNAM="OFF", ONAM="ON") - builder.longIn.assert_any_call(f"{DEVICE}:ReadInt") - builder.aIn.assert_called_once_with(f"{DEVICE}:ReadWriteFloat_RBV", PREC=2) + builder.longIn.assert_any_call(f"{DEVICE}:ReadInt", **DEFAULT_SCALAR_FIELD_ARGS) + builder.aIn.assert_called_once_with( + f"{DEVICE}:ReadWriteFloat_RBV", PREC=2, **DEFAULT_SCALAR_FIELD_ARGS + ) builder.aOut.assert_any_call( - f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, PREC=2 + f"{DEVICE}:ReadWriteFloat", + always_update=True, + on_update=mocker.ANY, + PREC=2, + **DEFAULT_SCALAR_FIELD_ARGS, + ) + builder.longIn.assert_any_call(f"{DEVICE}:BigEnum", **DEFAULT_SCALAR_FIELD_ARGS) + builder.longIn.assert_any_call( + f"{DEVICE}:ReadWriteInt_RBV", **DEFAULT_SCALAR_FIELD_ARGS ) - builder.longIn.assert_any_call(f"{DEVICE}:BigEnum") - builder.longIn.assert_any_call(f"{DEVICE}:ReadWriteInt_RBV") builder.longOut.assert_called_with( - f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY + f"{DEVICE}:ReadWriteInt", + always_update=True, + on_update=mocker.ANY, + **DEFAULT_SCALAR_FIELD_ARGS, ) builder.mbbIn.assert_called_once_with( f"{DEVICE}:StringEnum_RBV", ZRST="red", ONST="green", TWST="blue" @@ -323,7 +365,9 @@ def test_add_pvi_info_with_parent(mocker: MockerFixture): ) -def test_add_sub_controller_pvi_info(mocker: MockerFixture): +def test_add_sub_controller_pvi_info( + mocker: MockerFixture, ioc_without_mapping: EpicsIOC +): add_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_pvi_info") controller = mocker.MagicMock() controller.path = [] @@ -331,17 +375,17 @@ def test_add_sub_controller_pvi_info(mocker: MockerFixture): child.path = ["Child"] controller.get_sub_controllers.return_value = {"d": child} - _add_sub_controller_pvi_info(DEVICE, controller) + ioc_without_mapping._add_sub_controller_pvi_info(DEVICE, controller) add_pvi_info.assert_called_once_with( f"{DEVICE}:Child:PVI", f"{DEVICE}:PVI", "child" ) -def test_add_attr_pvi_info(mocker: MockerFixture): +def test_add_attr_pvi_info(mocker: MockerFixture, ioc_without_mapping: EpicsIOC): record = mocker.MagicMock() - _add_attr_pvi_info(record, DEVICE, "attr", "r") + ioc_without_mapping._add_attr_pvi_info(record, DEVICE, "attr", "r") record.add_info.assert_called_once_with( "Q:group", @@ -390,9 +434,10 @@ def test_long_pv_names_discarded(mocker: MockerFixture): f"{DEVICE}:{short_pv_name}", always_update=True, on_update=mocker.ANY, + **DEFAULT_SCALAR_FIELD_ARGS, ) builder.longIn.assert_called_once_with( - f"{DEVICE}:{short_pv_name}_RBV", + f"{DEVICE}:{short_pv_name}_RBV", **DEFAULT_SCALAR_FIELD_ARGS ) long_pv_name = long_attr_name.title().replace("_", "") diff --git a/tests/test_attribute.py b/tests/test_attribute.py index eca1ffb80..5981869d9 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -1,14 +1,15 @@ from functools import partial +import numpy as np import pytest from fastcs.attributes import AttrR, AttrRW -from fastcs.datatypes import Int, String +from fastcs.datatypes import Int, String, WaveForm @pytest.mark.asyncio async def test_attributes(): - device = {"state": "Idle", "number": 1, "count": False} + device = {"state": "Idle", "number": 1, "count": False, "array": None} ui = {"state": "", "number": 0, "count": False} async def update_ui(value, key): @@ -17,17 +18,30 @@ async def update_ui(value, key): async def send(value, key): device[key] = value - async def device_add(): - device["number"] += 1 - attr_r = AttrR(String()) attr_r.set_update_callback(partial(update_ui, key="state")) await attr_r.set(device["state"]) assert ui["state"] == "Idle" - attr_rw = AttrRW(Int()) + attr_rw = AttrRW(Int(max=10)) attr_rw.set_process_callback(partial(send, key="number")) attr_rw.set_write_display_callback(partial(update_ui, key="number")) - await attr_rw.process(2) - assert device["number"] == 2 - assert ui["number"] == 2 + await attr_rw.process(10) + assert device["number"] == 10 + assert ui["number"] == 10 + with pytest.raises(ValueError): + await attr_rw.set(100) + + attr_rw = AttrRW(WaveForm(np.dtype("int32"), 10)) + attr_rw.set_process_callback(partial(send, key="array")) + await attr_rw.process(np.array(range(10), dtype="int32")) + assert np.array_equal(device["array"], np.array(range(10), dtype="int32")) + + with pytest.raises( + ValueError, match="Waveform length 11 is greater than maximum 10." + ): + await attr_rw.process(np.array(range(11), dtype="int32")) + with pytest.raises( + ValueError, match="Waveform dtype float64 does not match int32." + ): + await attr_rw.process(np.array(range(10), dtype="float64")) diff --git a/tests/test_backend.py b/tests/test_backend.py index 67ea12fe9..ae518059f 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,4 +1,4 @@ -from time import sleep +import asyncio import pytest @@ -16,7 +16,7 @@ async def init_task(self): self.init_task_called = True def _run(self): - pass + asyncio.run_coroutine_threadsafe(asyncio.sleep(0.3), self._loop) @pytest.mark.asyncio @@ -41,5 +41,7 @@ async def test_backend(controller): # Scan tasks should be running for _ in range(3): count = controller.count - sleep(0.05) + await asyncio.sleep(0.1) assert controller.count > count + + backend.stop_scan_tasks()