Skip to content

Commit 7cb7dd8

Browse files
committed
added WaveForm type
1 parent 3b915ca commit 7cb7dd8

File tree

4 files changed

+111
-158
lines changed

4 files changed

+111
-158
lines changed

src/fastcs/attributes.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import Any, Generic, Protocol, runtime_checkable
66

7-
from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T, validate_value
7+
from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T
88

99

1010
class AttrMode(Enum):
@@ -124,15 +124,17 @@ def __init__(
124124
allowed_values=allowed_values, # type: ignore
125125
description=description,
126126
)
127-
self._value: T = datatype.dtype() if initial_value is None else initial_value
127+
self._value: T = (
128+
datatype.initial_value if initial_value is None else initial_value
129+
)
128130
self._update_callback: AttrCallback[T] | None = None
129131
self._updater = handler
130132

131133
def get(self) -> T:
132134
return self._value
133135

134136
async def set(self, value: T) -> None:
135-
self._value = self._datatype.dtype(validate_value(self._datatype, value))
137+
self._value = self._datatype.cast(value)
136138

137139
if self._update_callback is not None:
138140
await self._update_callback(self._value)
@@ -175,11 +177,11 @@ async def process(self, value: T) -> None:
175177

176178
async def process_without_display_update(self, value: T) -> None:
177179
if self._process_callback is not None:
178-
await self._process_callback(self._datatype.dtype(value))
180+
await self._process_callback(self._datatype.cast(value))
179181

180182
async def update_display_without_process(self, value: T) -> None:
181183
if self._write_display_callback is not None:
182-
await self._write_display_callback(self._datatype.dtype(value))
184+
await self._write_display_callback(self._datatype.cast(value))
183185

184186
def set_process_callback(self, callback: AttrCallback[T] | None) -> None:
185187
self._process_callback = callback
@@ -219,6 +221,6 @@ def __init__(
219221
)
220222

221223
async def process(self, value: T) -> None:
222-
await self.set(validate_value(self._datatype, value))
224+
await self.set(value)
223225

224226
await super().process(value) # type: ignore

src/fastcs/backends/epics/ioc.py

Lines changed: 2 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Callable
22
from dataclasses import asdict, dataclass
33
from types import MethodType
4-
from typing import Any, Literal, cast
4+
from typing import Any, Literal
55

66
import numpy as np
77
from softioc import builder, fields, softioc
@@ -18,9 +18,8 @@
1818
enum_index_to_value,
1919
enum_value_to_index,
2020
)
21-
from epicsdbbuilder import RecordName
2221
from fastcs.controller import BaseController
23-
from fastcs.datatypes import Bool, DataType, Float, Int, String, T, Table, WaveForm
22+
from fastcs.datatypes import Bool, DataType, Float, Int, String, T, WaveForm
2423
from fastcs.exceptions import FastCSException
2524
from fastcs.mapping import Mapping
2625

@@ -302,114 +301,6 @@ def _add_attr_pvi_info(
302301
},
303302
)
304303

305-
def create_table(self, top_level_pv: str, datatype: np.dtype):
306-
pva_table_name = RecordName(top_level_pv)
307-
308-
columns = datatype.descr
309-
if not columns or not all(
310-
isinstance(column, tuple) and len(column) == 2 for column in columns
311-
):
312-
raise FastCSException("Table datatype must have a structured dtype.")
313-
314-
columns = cast(list[tuple[str, np.dtype]], columns)
315-
316-
labels_record: RecordWrapper = builder.WaveformOut(
317-
top_level_pv + ":LABELS",
318-
initial_value=np.array([name.encode() for (name, _) in columns]),
319-
)
320-
321-
labels_record.add_info(
322-
"Q:group",
323-
{
324-
pva_table_name: {
325-
"+id": "epics:nt/NTTable:1.0",
326-
"labels": {"+type": "plain", "+channel": "VAL"},
327-
}
328-
},
329-
)
330-
331-
pv_rec = builder.longStringIn(
332-
top_level_pv + ":PV",
333-
initial_value=pva_table_name,
334-
)
335-
block, field = top_level_pv.rsplit(":", maxsplit=1)
336-
_add_pvi_info(field, block, field.lower())
337-
338-
self.table_fields_records = OrderedDict(
339-
{
340-
k: TableFieldRecordContainer(v, None)
341-
for k, v in field_info.fields.items()
342-
}
343-
)
344-
self.all_values_dict = all_values_dict
345-
346-
pvi_table_name = epics_to_pvi_name(table_name)
347-
348-
# The PVI group to put all records into
349-
pvi_group = PviGroup.PARAMETERS
350-
Pvi.add_pvi_info(
351-
table_name,
352-
pvi_group,
353-
SignalRW(
354-
name=pvi_table_name,
355-
write_pv=f"{Pvi.record_prefix}:{table_name}",
356-
write_widget=TableWrite(widgets=[]),
357-
),
358-
)
359-
360-
# Note that the table_updater's table_fields are guaranteed sorted in bit order,
361-
# unlike field_info's fields. This means the record dict inside the table
362-
# updater are also in the same bit order.
363-
value = all_values_dict[table_name]
364-
assert isinstance(value, list)
365-
field_data = words_to_table(value, field_info)
366-
367-
for i, (field_name, field_record_container) in enumerate(
368-
self.table_fields_records.items()
369-
):
370-
field_details = field_record_container.field
371-
372-
full_name = table_name + ":" + field_name
373-
full_name = EpicsName(full_name)
374-
description = trim_description(field_details.description, full_name)
375-
376-
waveform_val = self._construct_waveform_val(
377-
field_data, field_name, field_details
378-
)
379-
380-
field_record: RecordWrapper = builder.WaveformOut(
381-
full_name,
382-
DESC=description,
383-
validate=self.validate_waveform,
384-
initial_value=waveform_val,
385-
length=field_info.max_length,
386-
)
387-
388-
field_pva_info = {
389-
"+type": "plain",
390-
"+channel": "VAL",
391-
"+putorder": i + 1,
392-
"+trigger": "",
393-
}
394-
395-
pva_info = {f"value.{field_name.lower()}": field_pva_info}
396-
397-
# For the last column in the table
398-
if i == len(self.table_fields_records) - 1:
399-
# Trigger a monitor update
400-
field_pva_info["+trigger"] = "*"
401-
# Add metadata
402-
pva_info[""] = {"+type": "meta", "+channel": "VAL"}
403-
404-
field_record.add_info(
405-
"Q:group",
406-
{pva_table_name: pva_info},
407-
)
408-
409-
field_record_container.record_info = RecordInfo(lambda x: x, None, False)
410-
411-
field_record_container.record_info.add_record(field_record)
412-
413304

414305
def _add_pvi_info(
415306
pvi: str,
@@ -498,8 +389,6 @@ def datatype_updater(datatype: DataType):
498389
return builder.WaveformIn(
499390
pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields
500391
)
501-
case Table(numpy_datatype):
502-
return create_table(pv, numpy_datatype)
503392
case _:
504393
raise FastCSException(
505394
f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}"

src/fastcs/datatypes.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
import numpy as np
4-
from numpy import typing as npt
53
from abc import abstractmethod
6-
from collections.abc import Awaitable, Callable
4+
from collections.abc import Awaitable, Callable, Sequence
75
from dataclasses import dataclass
8-
from typing import Generic, TypeVar
6+
from typing import Any, Generic, Literal, TypeVar
7+
8+
import numpy as np
99

10-
T = TypeVar("T", int, float, bool, str, npt.ArrayLike)
10+
T = TypeVar("T", int, float, bool, str, np.ndarray) # type: ignore
1111
ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore
1212

1313

@@ -20,7 +20,20 @@ class DataType(Generic[T]):
2020

2121
@property
2222
@abstractmethod
23-
def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars
23+
def dtype(
24+
self,
25+
) -> type[T]: # Using property due to lack of Generic ClassVars
26+
pass
27+
28+
@property
29+
@abstractmethod
30+
def initial_value(self) -> T:
31+
"""Return an initial value for the datatype."""
32+
pass
33+
34+
@abstractmethod
35+
def cast(self, value: Any) -> T:
36+
"""Cast a value to the datatype to put to the backend."""
2437
pass
2538

2639

@@ -38,6 +51,17 @@ class Int(DataType[int]):
3851
def dtype(self) -> type[int]:
3952
return int
4053

54+
@property
55+
def initial_value(self) -> Literal[0]:
56+
return 0
57+
58+
def cast(self, value: Any) -> int:
59+
if self.min is not None and value < self.min:
60+
raise ValueError(f"Value {value} is less than minimum {self.min}")
61+
if self.max is not None and value > self.max:
62+
raise ValueError(f"Value {value} is greater than maximum {self.max}")
63+
return int(value)
64+
4165

4266
@dataclass(frozen=True)
4367
class Float(DataType[float]):
@@ -54,6 +78,17 @@ class Float(DataType[float]):
5478
def dtype(self) -> type[float]:
5579
return float
5680

81+
@property
82+
def intial_value(self) -> float:
83+
return 0.0
84+
85+
def cast(self, value: Any) -> float:
86+
if self.min is not None and value < self.min:
87+
raise ValueError(f"Value {value} is less than minimum {self.min}")
88+
if self.max is not None and value > self.max:
89+
raise ValueError(f"Value {value} is greater than maximum {self.max}")
90+
return float(value)
91+
5792

5893
@dataclass(frozen=True)
5994
class Bool(DataType[bool]):
@@ -66,6 +101,13 @@ class Bool(DataType[bool]):
66101
def dtype(self) -> type[bool]:
67102
return bool
68103

104+
@property
105+
def intial_value(self) -> Literal[False]:
106+
return False
107+
108+
def cast(self, value: Any) -> bool:
109+
return bool(value)
110+
69111

70112
@dataclass(frozen=True)
71113
class String(DataType[str]):
@@ -75,39 +117,45 @@ class String(DataType[str]):
75117
def dtype(self) -> type[str]:
76118
return str
77119

120+
@property
121+
def intial_value(self) -> Literal[""]:
122+
return ""
78123

79-
@dataclass(frozen=True)
80-
class WaveForm(DataType[npt.ArrayLike]):
81-
"""DataType for a waveform"""
124+
def cast(self, value: Any) -> str:
125+
return str(value)
82126

83-
length: int | None = None
84127

85-
@property
86-
def dtype(self) -> type[npt.ArrayLike]:
87-
return np.ndarray
128+
DEFAULT_WAVEFORM_LENGTH = 20000
88129

89130

90131
@dataclass(frozen=True)
91-
class Table(DataType[npt.ArrayLike]):
92-
"""`DataType` mapping to a dictionary of numpy arrays.
93-
94-
Values should be a dictionary of column name to an `ArrayLike` of columns.
132+
class WaveForm(DataType[np.ndarray]):
133+
"""
134+
DataType for a waveform, values are of the numpy `datatype`
95135
"""
96136

97-
numpy_datatype: npt.DTypeLike
137+
numpy_datatype: np.dtype
138+
length: int = DEFAULT_WAVEFORM_LENGTH
98139

99140
@property
100-
def dtype(self) -> type[npt.ArrayLike]:
141+
def dtype(self) -> type[np.ndarray]:
101142
return np.ndarray
102143

103-
104-
def validate_value(datatype: DataType[T], value: T) -> T:
105-
"""Validate a value against a datatype."""
106-
107-
if isinstance(datatype, (Int | Float)):
108-
assert isinstance(value, (int | float)), f"Value {value} is not a number"
109-
if datatype.min is not None and value < datatype.min:
110-
raise ValueError(f"Value {value} is less than minimum {datatype.min}")
111-
if datatype.max is not None and value > datatype.max:
112-
raise ValueError(f"Value {value} is greater than maximum {datatype.max}")
113-
return value
144+
@property
145+
def initial_value(self) -> np.ndarray:
146+
return np.ndarray(self.length, dtype=self.numpy_datatype)
147+
148+
def cast(self, value: Sequence | np.ndarray) -> np.ndarray:
149+
if len(value) > self.length:
150+
raise ValueError(
151+
f"Waveform length {len(value)} is greater than maximum {self.length}."
152+
)
153+
if isinstance(value, np.ndarray):
154+
if value.dtype != self.numpy_datatype:
155+
raise ValueError(
156+
f"Waveform dtype {value.dtype} does not "
157+
f"match {self.numpy_datatype}."
158+
)
159+
return value
160+
else:
161+
return np.array(value, dtype=self.numpy_datatype)

0 commit comments

Comments
 (0)