Skip to content

Commit dda9e45

Browse files
committed
Refactor enum logic
Add some type ignores for bugs in mypy. These will be removed when moved to pyright.
1 parent d254dd3 commit dda9e45

File tree

4 files changed

+245
-58
lines changed

4 files changed

+245
-58
lines changed

src/fastcs/backends/epics/ioc.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
from fastcs.attributes import AttrR, AttrRW, AttrW
1111
from fastcs.backends.epics.util import (
12-
MBB_MAX_CHOICES,
1312
MBB_STATE_FIELDS,
14-
convert_if_enum,
13+
attr_is_enum,
14+
enum_index_to_value,
15+
enum_value_to_index,
1516
)
1617
from fastcs.controller import BaseController
1718
from fastcs.datatypes import Bool, Float, Int, String, T
@@ -155,23 +156,25 @@ def _create_and_link_attribute_pvs(pv_prefix: str, mapping: Mapping) -> None:
155156
def _create_and_link_read_pv(
156157
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T]
157158
) -> None:
158-
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
159+
if attr_is_enum(attribute):
159160

160-
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")
161+
async def async_record_set(value: T):
162+
record.set(enum_value_to_index(attribute, value))
163+
else:
161164

162-
async def async_record_set(value: T):
163-
record.set(convert_if_enum(attribute, value))
165+
async def async_record_set(value: T): # type: ignore
166+
record.set(value)
167+
168+
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
169+
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")
164170

165171
attribute.set_update_callback(async_record_set)
166172

167173

168174
def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
169-
if (
170-
isinstance(attribute.datatype, String)
171-
and attribute.allowed_values is not None
172-
and len(attribute.allowed_values) <= MBB_MAX_CHOICES
173-
):
174-
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
175+
if attr_is_enum(attribute):
176+
# https://github.com/python/mypy/issues/16789
177+
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
175178
return builder.mbbIn(pv, **state_keys)
176179

177180
match attribute.datatype:
@@ -192,40 +195,35 @@ def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
192195
def _create_and_link_write_pv(
193196
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T]
194197
) -> None:
195-
async def on_update(value):
196-
if (
197-
isinstance(attribute.datatype, String)
198-
and isinstance(value, int)
199-
and attribute.allowed_values is not None
200-
):
201-
try:
202-
value = attribute.allowed_values[value]
203-
except IndexError:
204-
raise IndexError(
205-
f"Invalid index {value}, allowed values: {attribute.allowed_values}"
206-
) from None
207-
208-
await attribute.process_without_display_update(value)
198+
if attr_is_enum(attribute):
199+
200+
async def on_update(value):
201+
await attribute.process_without_display_update(
202+
enum_index_to_value(attribute, value)
203+
)
204+
205+
async def async_write_display(value: T):
206+
record.set(enum_value_to_index(attribute, value), process=False)
207+
208+
else:
209+
210+
async def on_update(value):
211+
await attribute.process_without_display_update(value)
212+
213+
async def async_write_display(value: T): # type: ignore
214+
record.set(value, process=False)
209215

210216
record = _get_output_record(
211217
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update
212218
)
213-
214219
_add_attr_pvi_info(record, pv_prefix, attr_name, "w")
215220

216-
async def async_record_set(value: T):
217-
record.set(convert_if_enum(attribute, value), process=False)
218-
219-
attribute.set_write_display_callback(async_record_set)
221+
attribute.set_write_display_callback(async_write_display)
220222

221223

222224
def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
223-
if (
224-
isinstance(attribute.datatype, String)
225-
and attribute.allowed_values is not None
226-
and len(attribute.allowed_values) <= MBB_MAX_CHOICES
227-
):
228-
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
225+
if attr_is_enum(attribute):
226+
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
229227
return builder.mbbOut(pv, always_update=True, on_update=on_update, **state_keys)
230228

231229
match attribute.datatype:

src/fastcs/backends/epics/util.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,75 @@
2525
MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES)
2626

2727

28-
def convert_if_enum(attribute: Attribute[T], value: T) -> T | int:
29-
"""Check if `attribute` is a string enum and if so convert `value` to index of enum.
28+
def attr_is_enum(attribute: Attribute) -> bool:
29+
"""Check if the `Attribute` has a `String` datatype and has `allowed_values` set.
3030
3131
Args:
32-
`attribute`: The attribute to be set
33-
`value`: The value
32+
attribute: The `Attribute` to check
3433
3534
Returns:
36-
The index of the `value` if the `attribute` is an enum, else `value`
37-
38-
Raises:
39-
ValueError: If `attribute` is an enum and `value` is not in its allowed values
35+
`True` if `Attribute` is an enum, else `False`
4036
4137
"""
4238
match attribute:
4339
case Attribute(
4440
datatype=String(), allowed_values=allowed_values
4541
) if allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES:
46-
if value in allowed_values:
47-
return allowed_values.index(value)
48-
else:
49-
raise ValueError(f"'{value}' not in allowed values {allowed_values}")
42+
return True
5043
case _:
51-
return value
44+
return False
45+
46+
47+
def enum_value_to_index(attribute: Attribute[T], value: T) -> int:
48+
"""Convert the given value to the index within the allowed_values of the Attribute
49+
50+
Args:
51+
`attribute`: The attribute
52+
`value`: The value to convert
53+
54+
Returns:
55+
The index of the `value`
56+
57+
Raises:
58+
ValueError: If `attribute` has no allowed values or `value` is not a valid
59+
option
60+
61+
"""
62+
if attribute.allowed_values is None:
63+
raise ValueError(
64+
"Cannot convert value to index for Attribute without allowed values"
65+
)
66+
67+
try:
68+
return attribute.allowed_values.index(value)
69+
except ValueError:
70+
raise ValueError(
71+
f"{value} not in allowed values of {attribute}: {attribute.allowed_values}"
72+
) from None
73+
74+
75+
def enum_index_to_value(attribute: Attribute[T], index: int) -> T:
76+
"""Lookup the value from the allowed_values of an attribute at the given index.
77+
78+
Parameters:
79+
attribute: The `Attribute` to lookup the index from
80+
index: The index of the value to retrieve
81+
82+
Returns:
83+
The value at the specified index in the allowed values list.
84+
85+
Raises:
86+
IndexError: If the index is out of bounds
87+
88+
"""
89+
if attribute.allowed_values is None:
90+
raise ValueError(
91+
"Cannot lookup value by index for Attribute without allowed values"
92+
)
93+
94+
try:
95+
return attribute.allowed_values[index]
96+
except IndexError:
97+
raise IndexError(
98+
f"Invalid index {index} into allowed values: {attribute.allowed_values}"
99+
) from None

tests/backends/epics/test_ioc.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
_add_attr_pvi_info,
1111
_add_pvi_info,
1212
_add_sub_controller_pvi_info,
13+
_create_and_link_read_pv,
14+
_create_and_link_write_pv,
1315
_get_input_record,
1416
_get_output_record,
1517
)
@@ -25,6 +27,54 @@
2527
ONOFF_STATES = {"ZRST": "disabled", "ONST": "enabled"}
2628

2729

30+
@pytest.mark.asyncio
31+
async def test_create_and_link_read_pv(mocker: MockerFixture):
32+
get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record")
33+
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
34+
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
35+
record = get_input_record.return_value
36+
37+
attribute = mocker.MagicMock()
38+
39+
attr_is_enum.return_value = False
40+
_create_and_link_read_pv("PREFIX", "PV", "attr", attribute)
41+
42+
get_input_record.assert_called_once_with("PREFIX:PV", attribute)
43+
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r")
44+
45+
# Extract the callback generated and set in the function and call it
46+
attribute.set_update_callback.assert_called_once_with(mocker.ANY)
47+
record_set_callback = attribute.set_update_callback.call_args[0][0]
48+
await record_set_callback(1)
49+
50+
record.set.assert_called_once_with(1)
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_create_and_link_read_pv_enum(mocker: MockerFixture):
55+
get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record")
56+
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
57+
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
58+
record = get_input_record.return_value
59+
enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index")
60+
61+
attribute = mocker.MagicMock()
62+
63+
attr_is_enum.return_value = True
64+
_create_and_link_read_pv("PREFIX", "PV", "attr", attribute)
65+
66+
get_input_record.assert_called_once_with("PREFIX:PV", attribute)
67+
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r")
68+
69+
# Extract the callback generated and set in the function and call it
70+
attribute.set_update_callback.assert_called_once_with(mocker.ANY)
71+
record_set_callback = attribute.set_update_callback.call_args[0][0]
72+
await record_set_callback(1)
73+
74+
enum_value_to_index.assert_called_once_with(attribute, 1)
75+
record.set.assert_called_once_with(enum_value_to_index.return_value)
76+
77+
2878
@pytest.mark.parametrize(
2979
"attribute,record_type,kwargs",
3080
(
@@ -57,6 +107,75 @@ def test_get_input_record_raises(mocker: MockerFixture):
57107
_get_input_record("PV", mocker.MagicMock())
58108

59109

110+
@pytest.mark.asyncio
111+
async def test_create_and_link_write_pv(mocker: MockerFixture):
112+
get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record")
113+
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
114+
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
115+
record = get_output_record.return_value
116+
117+
attribute = mocker.MagicMock()
118+
attribute.process_without_display_update = mocker.AsyncMock()
119+
120+
attr_is_enum.return_value = False
121+
_create_and_link_write_pv("PREFIX", "PV", "attr", attribute)
122+
123+
get_output_record.assert_called_once_with(
124+
"PREFIX:PV", attribute, on_update=mocker.ANY
125+
)
126+
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w")
127+
128+
# Extract the write update callback generated and set in the function and call it
129+
attribute.set_write_display_callback.assert_called_once_with(mocker.ANY)
130+
write_display_callback = attribute.set_write_display_callback.call_args[0][0]
131+
await write_display_callback(1)
132+
133+
record.set.assert_called_once_with(1, process=False)
134+
135+
# Extract the on update callback generated and set in the function and call it
136+
on_update_callback = get_output_record.call_args[1]["on_update"]
137+
await on_update_callback(1)
138+
139+
attribute.process_without_display_update.assert_called_once_with(1)
140+
141+
142+
@pytest.mark.asyncio
143+
async def test_create_and_link_write_pv_enum(mocker: MockerFixture):
144+
get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record")
145+
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
146+
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
147+
enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index")
148+
enum_index_to_value = mocker.patch("fastcs.backends.epics.ioc.enum_index_to_value")
149+
record = get_output_record.return_value
150+
151+
attribute = mocker.MagicMock()
152+
attribute.process_without_display_update = mocker.AsyncMock()
153+
154+
attr_is_enum.return_value = True
155+
_create_and_link_write_pv("PREFIX", "PV", "attr", attribute)
156+
157+
get_output_record.assert_called_once_with(
158+
"PREFIX:PV", attribute, on_update=mocker.ANY
159+
)
160+
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w")
161+
162+
# Extract the write update callback generated and set in the function and call it
163+
attribute.set_write_display_callback.assert_called_once_with(mocker.ANY)
164+
write_display_callback = attribute.set_write_display_callback.call_args[0][0]
165+
await write_display_callback(1)
166+
167+
enum_value_to_index.assert_called_once_with(attribute, 1)
168+
record.set.assert_called_once_with(enum_value_to_index.return_value, process=False)
169+
170+
# Extract the on update callback generated and set in the function and call it
171+
on_update_callback = get_output_record.call_args[1]["on_update"]
172+
await on_update_callback(1)
173+
174+
attribute.process_without_display_update.assert_called_once_with(
175+
enum_index_to_value.return_value
176+
)
177+
178+
60179
@pytest.mark.parametrize(
61180
"attribute,record_type,kwargs",
62181
(

tests/backends/epics/test_util.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,39 @@
11
import pytest
22

33
from fastcs.attributes import AttrR
4-
from fastcs.backends.epics.util import convert_if_enum
4+
from fastcs.backends.epics.util import (
5+
attr_is_enum,
6+
enum_index_to_value,
7+
enum_value_to_index,
8+
)
59
from fastcs.datatypes import String
610

711

8-
def test_convert_if_enum():
9-
string_attr = AttrR(String())
10-
enum_attr = AttrR(String(), allowed_values=["disabled", "enabled"])
12+
def test_attr_is_enum():
13+
assert not attr_is_enum(AttrR(String()))
14+
assert attr_is_enum(AttrR(String(), allowed_values=["disabled", "enabled"]))
1115

12-
assert convert_if_enum(string_attr, "enabled") == "enabled"
1316

14-
assert convert_if_enum(enum_attr, "enabled") == 1
17+
def test_enum_index_to_value():
18+
"""Test enum_index_to_value."""
19+
attribute = AttrR(String(), allowed_values=["disabled", "enabled"])
1520

16-
with pytest.raises(ValueError):
17-
convert_if_enum(enum_attr, "off")
21+
assert enum_index_to_value(attribute, 0) == "disabled"
22+
assert enum_index_to_value(attribute, 1) == "enabled"
23+
with pytest.raises(IndexError, match="Invalid index"):
24+
enum_index_to_value(attribute, 2)
25+
26+
with pytest.raises(ValueError, match="Cannot lookup value by index"):
27+
enum_index_to_value(AttrR(String()), 0)
28+
29+
30+
def test_enum_value_to_index():
31+
attribute = AttrR(String(), allowed_values=["disabled", "enabled"])
32+
33+
assert enum_value_to_index(attribute, "disabled") == 0
34+
assert enum_value_to_index(attribute, "enabled") == 1
35+
with pytest.raises(ValueError, match="not in allowed values"):
36+
enum_value_to_index(attribute, "off")
37+
38+
with pytest.raises(ValueError, match="Cannot convert value to index"):
39+
enum_value_to_index(AttrR(String()), "disabled")

0 commit comments

Comments
 (0)