Skip to content

Commit 49e799e

Browse files
authored
Correctly implement DType and DType_T (#285)
1 parent 40aa634 commit 49e799e

File tree

12 files changed

+46
-54
lines changed

12 files changed

+46
-54
lines changed

src/fastcs/attributes/attribute_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ async def send(self, attr: AttrW[DType_T, AttributeIORefT], value: DType_T) -> N
3737
raise NotImplementedError()
3838

3939

40-
AnyAttributeIO = AttributeIO[DType_T, AttributeIORef]
40+
AnyAttributeIO = AttributeIO[Any]

src/fastcs/controllers/base_controller.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,7 @@
55
from copy import deepcopy
66
from typing import _GenericAlias, get_args, get_origin, get_type_hints # type: ignore
77

8-
from fastcs.attributes import (
9-
Attribute,
10-
AttributeIO,
11-
AttributeIORefT,
12-
AttrR,
13-
AttrW,
14-
HintedAttribute,
15-
)
16-
from fastcs.datatypes import DType_T
8+
from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
179
from fastcs.logging import bind_logger
1810
from fastcs.tracer import Tracer
1911

@@ -41,7 +33,7 @@ def __init__(
4133
self,
4234
path: list[str] | None = None,
4335
description: str | None = None,
44-
ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None,
36+
ios: Sequence[AnyAttributeIO] | None = None,
4537
) -> None:
4638
super().__init__()
4739

@@ -125,7 +117,7 @@ class method and a controller instance, so that it can be called from any
125117
elif isinstance(attr, UnboundScan | UnboundCommand):
126118
setattr(self, attr_name, attr.bind(self))
127119

128-
def _validate_io(self, ios: Sequence[AttributeIO[DType_T, AttributeIORefT]]):
120+
def _validate_io(self, ios: Sequence[AnyAttributeIO]):
129121
"""Validate that there is exactly one AttributeIO class registered to the
130122
controller for each type of AttributeIORef belonging to the attributes of the
131123
controller"""

src/fastcs/controllers/controller.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from collections.abc import Sequence
22

3-
from fastcs.attributes import AttributeIO, AttributeIORefT
3+
from fastcs.attributes import AnyAttributeIO
44
from fastcs.controllers.base_controller import BaseController
5-
from fastcs.datatypes import DType_T
65

76

87
class Controller(BaseController):
@@ -11,7 +10,7 @@ class Controller(BaseController):
1110
def __init__(
1211
self,
1312
description: str | None = None,
14-
ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None,
13+
ios: Sequence[AnyAttributeIO] | None = None,
1514
) -> None:
1615
super().__init__(description=description, ios=ios)
1716

src/fastcs/controllers/controller_vector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from collections.abc import Iterator, Mapping, MutableMapping, Sequence
22

3-
from fastcs.attributes import AttributeIO, AttributeIORefT
3+
from fastcs.attributes import AnyAttributeIO
44
from fastcs.controllers.base_controller import BaseController
55
from fastcs.controllers.controller import Controller
6-
from fastcs.datatypes import DType_T
76

87

98
class ControllerVector(MutableMapping[int, Controller], BaseController):
@@ -18,7 +17,7 @@ def __init__(
1817
self,
1918
children: Mapping[int, Controller],
2019
description: str | None = None,
21-
ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None,
20+
ios: Sequence[AnyAttributeIO] | None = None,
2221
) -> None:
2322
super().__init__(description=description, ios=ios)
2423
self._children: dict[int, Controller] = {}

src/fastcs/datatypes/datatype.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,11 @@
1313
| enum.Enum # Enum
1414
| np.ndarray # Waveform / Table
1515
)
16-
17-
DType_T = TypeVar(
18-
"DType_T",
19-
int, # Int
20-
float, # Float
21-
bool, # Bool
22-
str, # String
23-
enum.Enum, # Enum
24-
np.ndarray, # Waveform / Table
25-
)
2616
"""A builtin (or numpy) type supported by a corresponding FastCS Attribute DataType"""
2717

18+
DType_T = TypeVar("DType_T", bound=DType)
19+
"""A TypeVar of `DType` for use in generic classes and functions"""
20+
2821

2922
@dataclass(frozen=True)
3023
class DataType(Generic[DType_T]):

src/fastcs/transports/epics/ca/util.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import enum
12
from dataclasses import asdict
23
from typing import Any
34

45
from softioc import builder
56

67
from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW
7-
from fastcs.datatypes import Bool, DataType, DType_T, Enum, Float, Int, String, Waveform
8+
from fastcs.datatypes import Bool, DType_T, Enum, Float, Int, String, Waveform
9+
from fastcs.datatypes.datatype import DataType
810
from fastcs.exceptions import FastCSError
911

1012
_MBB_FIELD_PREFIXES = (
@@ -31,7 +33,7 @@
3133
MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES)
3234

3335

34-
EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform)
36+
EPICS_ALLOWED_DATATYPES = (Bool, Enum, Float, Int, String, Waveform)
3537
DEFAULT_STRING_WAVEFORM_LENGTH = 256
3638

3739
DATATYPE_FIELD_TO_RECORD_FIELD = {
@@ -44,9 +46,7 @@
4446
}
4547

4648

47-
def record_metadata_from_attribute(
48-
attribute: Attribute[DType_T],
49-
) -> dict[str, Any]:
49+
def record_metadata_from_attribute(attribute: Attribute[DType_T]) -> dict[str, Any]:
5050
"""Converts attributes on the `Attribute` to the
5151
field name/value in the record metadata."""
5252
metadata: dict[str, Any] = {"DESC": attribute.description}
@@ -62,7 +62,7 @@ def record_metadata_from_attribute(
6262

6363

6464
def record_metadata_from_datatype(
65-
datatype: DataType[DType_T], out_record: bool = False
65+
datatype: DataType[Any], out_record: bool = False
6666
) -> dict[str, str]:
6767
"""Converts attributes on the `DataType` to the
6868
field name/value in the record metadata."""
@@ -123,9 +123,14 @@ def cast_from_epics_type(datatype: DataType[DType_T], value: object) -> DType_T:
123123
raise ValueError(f"Invalid bool value from EPICS record {value}")
124124
case Enum():
125125
if len(datatype.members) <= MBB_MAX_CHOICES:
126+
assert isinstance(value, int), "Got non-integer value for Enum"
126127
return datatype.validate(datatype.members[value])
127128
else: # enum backed by string record
128-
return datatype.validate(datatype.enum_cls[value])
129+
assert isinstance(value, str), "Got non-string value for long Enum"
130+
# python typing can't narrow the nested generic enum_cls
131+
assert issubclass(datatype.enum_cls, enum.Enum), "Invalid Enum.enum_cls"
132+
enum_member = datatype.enum_cls[value]
133+
return datatype.validate(enum_member)
129134
case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES):
130135
return datatype.validate(value) # type: ignore
131136
case _:

src/fastcs/transports/epics/pva/types.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from p4p.nt import NTEnum, NTNDArray, NTScalar, NTTable
88

99
from fastcs.attributes import Attribute, AttrR, AttrW
10-
from fastcs.datatypes import Bool, DType_T, Enum, Float, Int, String, Table, Waveform
10+
from fastcs.datatypes import Bool, DType, Enum, Float, Int, String, Table, Waveform
11+
from fastcs.datatypes.datatype import DType_T
1112

1213
P4P_ALLOWED_DATATYPES = (Int, Float, String, Bool, Enum, Waveform, Table)
1314

@@ -90,7 +91,9 @@ def cast_from_p4p_value(attribute: Attribute[DType_T], value: object) -> DType_T
9091
"""Converts from a p4p value to a FastCS `Attribute` value."""
9192
match attribute.datatype:
9293
case Enum():
93-
return attribute.datatype.validate(attribute.datatype.members[value.index])
94+
assert hasattr(value, "index"), "Got non-enum p4p.Value for Enum DataType"
95+
index: int = value.index # pyright: ignore[reportAttributeAccessIssue]
96+
return attribute.datatype.validate(attribute.datatype.members[index])
9497
case Waveform(shape=shape):
9598
# p4p sends a flattened array
9699
assert value.shape == (math.prod(shape),)
@@ -154,7 +157,7 @@ def p4p_display(attribute: Attribute) -> dict:
154157
return {}
155158

156159

157-
def _p4p_check_numeric_for_alarm_states(datatype: Int | Float, value: DType_T) -> dict:
160+
def _p4p_check_numeric_for_alarm_states(datatype: Int | Float, value: DType) -> dict:
158161
low = None if datatype.min_alarm is None else value < datatype.min_alarm # type: ignore
159162
high = None if datatype.max_alarm is None else value > datatype.max_alarm # type: ignore
160163
severity = (

src/fastcs/transports/graphql/graphql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ async def _dynamic_f(value):
135135

136136
def _wrap_attr_get(
137137
attr_name: str, attribute: AttrR[DType_T]
138-
) -> Callable[[], Coroutine[Any, Any, Any]]:
138+
) -> Callable[[], Coroutine[Any, Any, DType_T]]:
139139
"""Wrap an attribute in a function with annotations for strawberry"""
140140

141-
async def _dynamic_f() -> Any:
141+
async def _dynamic_f() -> DType_T:
142142
return attribute.get()
143143

144144
_dynamic_f.__name__ = attr_name

src/fastcs/transports/rest/rest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def _get_response_body(attribute: AttrR[DType_T]):
9292

9393
def _wrap_attr_get(
9494
attribute: AttrR[DType_T],
95-
) -> Callable[[], Coroutine[Any, Any, Any]]:
96-
async def attr_get() -> Any: # Must be any as response_model is set
97-
value = attribute.get() # type: ignore
95+
) -> Callable[[], Coroutine[Any, Any, dict[str, object]]]:
96+
async def attr_get() -> dict[str, object]:
97+
value = attribute.get()
9898
return {"value": cast_to_rest_type(attribute.datatype, value)}
9999

100100
return attr_get

src/fastcs/transports/rest/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String)
66

77

8-
def convert_datatype(datatype: DataType[DType_T]) -> type:
8+
def convert_datatype(datatype: DataType[DType_T]) -> type[DType_T]:
99
"""Converts a datatype to a rest serialisable type."""
1010
match datatype:
1111
case Waveform():

0 commit comments

Comments
 (0)