Skip to content

Commit be2f8fc

Browse files
committed
Add AttrR wait_for_value
1 parent 5658972 commit be2f8fc

File tree

5 files changed

+175
-3
lines changed

5 files changed

+175
-3
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
("py:class", "asyncio.events.AbstractEventLoop"),
7878
("py:class", "asyncio.streams.StreamReader"),
7979
("py:class", "asyncio.streams.StreamWriter"),
80+
("py:class", "asyncio.locks.Event"),
8081
# Annoying error:
8182
# docstring of collections.abc.Callable:1: WARNING:
8283
# 'any' reference target not found: self [ref.any]

src/fastcs/attributes/attr_r.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from fastcs.attributes.attribute import Attribute
88
from fastcs.attributes.attribute_io_ref import AttributeIORefT
9+
from fastcs.attributes.util import AttrValuePredicate, PredicateEvent
910
from fastcs.datatypes import DataType, DType_T
1011
from fastcs.logging import bind_logger
1112

@@ -39,6 +40,8 @@ def __init__(
3940
"""Callback to update the value of the attribute with an IO to the source"""
4041
self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None
4142
"""Callbacks to publish changes to the value of the attribute"""
43+
self._on_update_events: set[PredicateEvent[DType_T]] = set()
44+
"""Events to set when the value satisifies some predicate"""
4245

4346
def get(self) -> DType_T:
4447
"""Get the cached value of the attribute."""
@@ -51,6 +54,9 @@ async def update(self, value: Any) -> None:
5154
generally only be called from an IO or a controller that is updating the value
5255
from some underlying source.
5356
57+
Any update callbacks will be called with the new value and any update events
58+
with predicates satisfied by the new value will be set.
59+
5460
To request a change to the setpoint of the attribute, use the ``put`` method,
5561
which will attempt to apply the change to the underlying source.
5662
@@ -67,6 +73,10 @@ async def update(self, value: Any) -> None:
6773

6874
self._value = self._datatype.validate(value)
6975

76+
self._on_update_events -= {
77+
e for e in self._on_update_events if e.set(self._value)
78+
}
79+
7080
if self._on_update_callbacks is not None:
7181
try:
7282
await asyncio.gather(
@@ -115,3 +125,69 @@ async def update_attribute():
115125
raise
116126

117127
return update_attribute
128+
129+
async def wait_for_predicate(
130+
self, predicate: AttrValuePredicate[DType_T], *, timeout: float
131+
):
132+
"""Wait for the predicate to be satisfied when called with the current value
133+
134+
Args:
135+
predicate: The predicate to test - a callable that takes the attribute
136+
value and returns True if the event should be set
137+
timeout: The timeout in seconds
138+
139+
"""
140+
if predicate(self._value):
141+
self.log_event(
142+
"Predicate already satisfied", predicate=predicate, attribute=self
143+
)
144+
return
145+
146+
self._on_update_events.add(update_event := PredicateEvent(predicate))
147+
148+
self.log_event("Waiting for predicate", predicate=predicate, attribute=self)
149+
try:
150+
await asyncio.wait_for(update_event.wait(), timeout)
151+
except TimeoutError:
152+
self._on_update_events.remove(update_event)
153+
raise TimeoutError(
154+
f"Timeout waiting {timeout}s for {self.full_name} predicate {predicate}"
155+
f" - current value: {self._value}"
156+
) from None
157+
158+
self.log_event("Predicate satisfied", predicate=predicate, attribute=self)
159+
160+
async def wait_for_value(self, target_value: DType_T, *, timeout: float):
161+
"""Wait for self._value to equal the target value
162+
163+
Args:
164+
target_value: The target value to wait for
165+
timeout: The timeout in seconds
166+
167+
Raises:
168+
TimeoutError: If the attribute does not reach the target value within the
169+
timeout
170+
171+
"""
172+
if self._value == target_value:
173+
self.log_event(
174+
"Current value already equals target value",
175+
target_value=target_value,
176+
attribute=self,
177+
)
178+
return
179+
180+
def predicate(v: DType_T) -> bool:
181+
return v == target_value
182+
183+
try:
184+
await self.wait_for_predicate(predicate, timeout=timeout)
185+
except TimeoutError:
186+
raise TimeoutError(
187+
f"Timeout waiting {timeout}s for {self.full_name} value {target_value}"
188+
f" - current value: {self._value}"
189+
) from None
190+
191+
self.log_event(
192+
"Value equals target value", target_valuevalue=target_value, attribute=self
193+
)

src/fastcs/attributes/attribute.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def name(self) -> str:
7070
def path(self) -> list[str]:
7171
return self._path
7272

73+
@property
74+
def full_name(self) -> str:
75+
return ".".join(self._path + [self._name])
76+
7377
def add_update_datatype_callback(
7478
self, callback: Callable[[DataType[DType_T]], None]
7579
) -> None:
@@ -102,7 +106,7 @@ def set_path(self, path: list[str]):
102106

103107
def __repr__(self):
104108
name = self.__class__.__name__
105-
path = ".".join(self._path + [self._name]) or None
109+
full_name = self.full_name or None
106110
datatype = self._datatype.__class__.__name__
107111

108-
return f"{name}(path={path}, datatype={datatype}, io_ref={self._io_ref})"
112+
return f"{name}(name={full_name}, datatype={datatype}, io_ref={self._io_ref})"

src/fastcs/attributes/util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
from dataclasses import dataclass, field
4+
from typing import Generic
5+
6+
from fastcs.datatypes import DType_T
7+
8+
AttrValuePredicate = Callable[[DType_T], bool]
9+
10+
11+
@dataclass(eq=False)
12+
class PredicateEvent(Generic[DType_T]):
13+
"""A wrapper of `asyncio.Event` that only triggers when a predicate is satisfied"""
14+
15+
_predicate: AttrValuePredicate[DType_T]
16+
"""Predicate to filter set calls by"""
17+
_event: asyncio.Event = field(default_factory=asyncio.Event)
18+
"""Event to set"""
19+
20+
def set(self, value: DType_T) -> bool:
21+
"""Set the event if the predicate is satisfied by the value
22+
23+
Returns:
24+
`True` if the predicate was satisfied and the event was set, else `False`
25+
26+
"""
27+
if self._predicate(value):
28+
self._event.set()
29+
return True
30+
31+
return False
32+
33+
async def wait(self):
34+
"""Wait for the event to be set"""
35+
await self._event.wait()
36+
37+
def __hash__(self) -> int:
38+
"""Make instances unique when stored in sets"""
39+
return id(self)

tests/test_attributes.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass
23
from functools import partial
34
from typing import Generic, TypeVar
@@ -12,7 +13,7 @@
1213
NumberT = TypeVar("NumberT", int, float)
1314

1415

15-
def test_attribute():
16+
def test_attr_r():
1617
attr = AttrR(String(), group="test group")
1718

1819
with pytest.raises(RuntimeError):
@@ -39,6 +40,57 @@ def test_attribute():
3940
assert attr.get() == ""
4041

4142

43+
@pytest.mark.asyncio
44+
async def test_wait_for_predicate(mocker: MockerFixture):
45+
attr = AttrR(Int(), initial_value=0)
46+
47+
async def update(attr: AttrR):
48+
while True:
49+
await asyncio.sleep(0.1)
50+
await attr.update(attr.get() + 3) # 3, 6, 9, 12 != 10
51+
52+
asyncio.create_task(update(attr))
53+
54+
# We won't see exactly 10 so check for greater than
55+
def predicate(v: int) -> bool:
56+
return v > 10
57+
58+
wait_mock = mocker.spy(asyncio, "wait_for")
59+
with pytest.raises(TimeoutError):
60+
await attr.wait_for_predicate(predicate, timeout=0.2)
61+
62+
await attr.wait_for_predicate(predicate, timeout=1)
63+
64+
assert wait_mock.call_count == 2
65+
66+
# Returns immediately without creating event if value already as expected
67+
await attr.wait_for_predicate(predicate, timeout=1)
68+
assert wait_mock.call_count == 2
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_wait_for_value(mocker: MockerFixture):
73+
attr = AttrR(Int(), initial_value=0)
74+
75+
async def update(attr: AttrR):
76+
await asyncio.sleep(0.5)
77+
await attr.update(1)
78+
79+
asyncio.create_task(update(attr))
80+
81+
wait_mock = mocker.spy(asyncio, "wait_for")
82+
with pytest.raises(TimeoutError):
83+
await attr.wait_for_value(10, timeout=0.2)
84+
85+
await attr.wait_for_value(1, timeout=1)
86+
87+
assert wait_mock.call_count == 2
88+
89+
# Returns immediately without creating event if value already as expected
90+
await attr.wait_for_value(1, timeout=1)
91+
assert wait_mock.call_count == 2
92+
93+
4294
@pytest.mark.asyncio
4395
async def test_attributes():
4496
device = {"state": "Idle", "number": 1, "count": False}

0 commit comments

Comments
 (0)