Skip to content

Commit 06b1462

Browse files
committed
added cast and initial_value methods to datatypes
1 parent 31b00d6 commit 06b1462

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

src/fastcs/attributes.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import Any, Generic, Protocol, runtime_checkable
66

7-
from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T, validate_value
7+
from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T
88

99

1010
class AttrMode(Enum):
@@ -126,15 +126,17 @@ def __init__(
126126
allowed_values=allowed_values, # type: ignore
127127
description=description,
128128
)
129-
self._value: T = datatype.dtype() if initial_value is None else initial_value
129+
self._value: T = (
130+
datatype.initial_value if initial_value is None else initial_value
131+
)
130132
self._update_callback: AttrCallback[T] | None = None
131133
self._updater = handler
132134

133135
def get(self) -> T:
134136
return self._value
135137

136138
async def set(self, value: T) -> None:
137-
self._value = self._datatype.dtype(validate_value(self._datatype, value))
139+
self._value = self._datatype.cast(value)
138140

139141
if self._update_callback is not None:
140142
await self._update_callback(self._value)
@@ -177,11 +179,11 @@ async def process(self, value: T) -> None:
177179

178180
async def process_without_display_update(self, value: T) -> None:
179181
if self._process_callback is not None:
180-
await self._process_callback(self._datatype.dtype(value))
182+
await self._process_callback(self._datatype.cast(value))
181183

182184
async def update_display_without_process(self, value: T) -> None:
183185
if self._write_display_callback is not None:
184-
await self._write_display_callback(self._datatype.dtype(value))
186+
await self._write_display_callback(self._datatype.cast(value))
185187

186188
def set_process_callback(self, callback: AttrCallback[T] | None) -> None:
187189
self._process_callback = callback
@@ -221,6 +223,6 @@ def __init__(
221223
)
222224

223225
async def process(self, value: T) -> None:
224-
await self.set(validate_value(self._datatype, value))
226+
await self.set(value)
225227

226228
await super().process(value) # type: ignore

src/fastcs/datatypes.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import abstractmethod
44
from collections.abc import Awaitable, Callable
55
from dataclasses import dataclass
6-
from typing import Generic, TypeVar
6+
from typing import Any, Generic, TypeVar
77

88
T = TypeVar("T", int, float, bool, str)
99
ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore
@@ -21,6 +21,18 @@ class DataType(Generic[T]):
2121
def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars
2222
pass
2323

24+
@abstractmethod
25+
def cast(self, value: T) -> Any:
26+
"""Cast a value to a more primative datatype for `Attribute` push.
27+
28+
Also validate it against fields in the datatype.
29+
"""
30+
pass
31+
32+
@property
33+
def initial_value(self) -> T:
34+
return self.dtype()
35+
2436

2537
T_Numerical = TypeVar("T_Numerical", int, float)
2638

@@ -33,6 +45,13 @@ class _Numerical(DataType[T_Numerical]):
3345
min_alarm: int | None = None
3446
max_alarm: int | None = None
3547

48+
def cast(self, value: T_Numerical) -> T_Numerical:
49+
if self.min is not None and value < self.min:
50+
raise ValueError(f"Value {value} is less than minimum {self.min}")
51+
if self.max is not None and value > self.max:
52+
raise ValueError(f"Value {value} is greater than maximum {self.max}")
53+
return value
54+
3655

3756
@dataclass(frozen=True)
3857
class Int(_Numerical[int]):
@@ -65,6 +84,9 @@ class Bool(DataType[bool]):
6584
def dtype(self) -> type[bool]:
6685
return bool
6786

87+
def cast(self, value: bool) -> bool:
88+
return value
89+
6890

6991
@dataclass(frozen=True)
7092
class String(DataType[str]):
@@ -74,14 +96,5 @@ class String(DataType[str]):
7496
def dtype(self) -> type[str]:
7597
return str
7698

77-
78-
def validate_value(datatype: DataType[T], value: T) -> T:
79-
"""Validate a value against a datatype."""
80-
81-
if isinstance(datatype, (Int | Float)):
82-
assert isinstance(value, (int | float)), f"Value {value} is not a number"
83-
if datatype.min is not None and value < datatype.min:
84-
raise ValueError(f"Value {value} is less than minimum {datatype.min}")
85-
if datatype.max is not None and value > datatype.max:
86-
raise ValueError(f"Value {value} is greater than maximum {datatype.max}")
87-
return value
99+
def cast(self, value: str) -> str:
100+
return value

0 commit comments

Comments
 (0)