Skip to content

Commit 9b2a037

Browse files
committed
Add AttrR wait_for_value
1 parent 5658972 commit 9b2a037

File tree

4 files changed

+160
-1
lines changed

4 files changed

+160
-1
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
("py:class", "asyncio.events.AbstractEventLoop"),
7878
("py:class", "asyncio.streams.StreamReader"),
7979
("py:class", "asyncio.streams.StreamWriter"),
80+
("py:class", "asyncio.locks.Event"),
8081
# Annoying error:
8182
# docstring of collections.abc.Callable:1: WARNING:
8283
# 'any' reference target not found: self [ref.any]

src/fastcs/attributes/attr_r.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from fastcs.attributes.attribute import Attribute
88
from fastcs.attributes.attribute_io_ref import AttributeIORefT
9+
from fastcs.attributes.util import AttrValuePredicate, PredicateEvent
910
from fastcs.datatypes import DataType, DType_T
1011
from fastcs.logging import bind_logger
1112

@@ -39,6 +40,8 @@ def __init__(
3940
"""Callback to update the value of the attribute with an IO to the source"""
4041
self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None
4142
"""Callbacks to publish changes to the value of the attribute"""
43+
self._on_update_events: set[PredicateEvent[DType_T]] = set()
44+
"""Events to set when the value satisifies some predicate"""
4245

4346
def get(self) -> DType_T:
4447
"""Get the cached value of the attribute."""
@@ -67,6 +70,10 @@ async def update(self, value: Any) -> None:
6770

6871
self._value = self._datatype.validate(value)
6972

73+
self._on_update_events -= {
74+
e for e in self._on_update_events if e.set(self._value)
75+
}
76+
7077
if self._on_update_callbacks is not None:
7178
try:
7279
await asyncio.gather(
@@ -115,3 +122,62 @@ async def update_attribute():
115122
raise
116123

117124
return update_attribute
125+
126+
async def wait_for_predicate(
127+
self, predicate: AttrValuePredicate[DType_T], *, timeout: float
128+
):
129+
"""Wait for the predicate to be satisfied when called with the current value
130+
131+
Args:
132+
predicate: The predicate to test - a callable that takes the attribute
133+
value and returns True if the event should be set
134+
timeout: The timeout in seconds
135+
136+
"""
137+
if predicate(self._value):
138+
self.log_event(
139+
"Predicate already satisfied", predicate=predicate, attribute=self
140+
)
141+
return
142+
143+
self._on_update_events.add(update_event := PredicateEvent(predicate))
144+
145+
self.log_event("Waiting for predicate", predicate=predicate, attribute=self)
146+
try:
147+
await asyncio.wait_for(update_event.wait(), timeout)
148+
except TimeoutError:
149+
self._on_update_events.remove(update_event)
150+
raise TimeoutError(
151+
f"Timeout waiting for predicate {predicate}. "
152+
f"Current value: {self._value}"
153+
) from None
154+
155+
self.log_event("Predicate satisfied", predicate=predicate, attribute=self)
156+
157+
async def wait_for_value(self, value: DType_T, *, timeout: float):
158+
"""Wait for value to equal the required value
159+
160+
Args:
161+
value: The value to wait for
162+
timeout: The timeout in seconds
163+
164+
Raises:
165+
TimeoutError: If the attribute does not reach the required value within the
166+
timeout
167+
168+
"""
169+
if self._value == value:
170+
self.log_event("Value already equal", value=value, attribute=self)
171+
return
172+
173+
def predicate(v: DType_T) -> bool:
174+
return v == value
175+
176+
try:
177+
await self.wait_for_predicate(predicate, timeout=timeout)
178+
except TimeoutError:
179+
raise TimeoutError(
180+
f"Timeout waiting for value {value}. Current value: {self._value}"
181+
) from None
182+
183+
self.log_event("Value equal", value=value, attribute=self)

src/fastcs/attributes/util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
from dataclasses import dataclass, field
4+
from typing import Generic
5+
6+
from fastcs.datatypes import DType_T
7+
8+
AttrValuePredicate = Callable[[DType_T], bool]
9+
10+
11+
@dataclass(eq=False)
12+
class PredicateEvent(Generic[DType_T]):
13+
"""A wrapper of `asyncio.Event` that only triggers when a predicate is satisfied"""
14+
15+
_predicate: AttrValuePredicate[DType_T]
16+
"""Predicate to filter set calls by"""
17+
_event: asyncio.Event = field(default_factory=asyncio.Event)
18+
"""Event to set"""
19+
20+
def set(self, value: DType_T):
21+
"""Set the event if the predicate is satisfied by the value
22+
23+
Returns:
24+
`True` if the event was set, else `False`
25+
26+
"""
27+
if self._predicate(value):
28+
self._event.set()
29+
return True
30+
31+
return False
32+
33+
async def wait(self):
34+
"""Wait for the event to be set"""
35+
await self._event.wait()
36+
37+
def __hash__(self) -> int:
38+
"""Make instances unique when stored in sets"""
39+
return id(self)

tests/test_attributes.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass
23
from functools import partial
34
from typing import Generic, TypeVar
@@ -12,7 +13,7 @@
1213
NumberT = TypeVar("NumberT", int, float)
1314

1415

15-
def test_attribute():
16+
def test_attr_r():
1617
attr = AttrR(String(), group="test group")
1718

1819
with pytest.raises(RuntimeError):
@@ -39,6 +40,58 @@ def test_attribute():
3940
assert attr.get() == ""
4041

4142

43+
@pytest.mark.timeout(5)
44+
@pytest.mark.asyncio
45+
async def test_wait_for_predicate(mocker: MockerFixture):
46+
attr = AttrR(Int(), initial_value=0)
47+
48+
async def update(attr: AttrR):
49+
while True:
50+
await asyncio.sleep(0.1)
51+
await attr.update(attr.get() + 3) # 3, 6, 9, 12 != 10
52+
53+
asyncio.create_task(update(attr))
54+
55+
# We won't see exactly 10 so check for greater than
56+
def predicate(v: int) -> bool:
57+
return v > 10
58+
59+
wait_mock = mocker.spy(asyncio, "wait_for")
60+
with pytest.raises(TimeoutError):
61+
await attr.wait_for_predicate(predicate, timeout=0.2)
62+
63+
await attr.wait_for_predicate(predicate, timeout=1)
64+
65+
assert wait_mock.call_count == 2
66+
67+
# Returns immediately without creating event if value already as expected
68+
await attr.wait_for_predicate(predicate, timeout=1)
69+
assert wait_mock.call_count == 2
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_wait_for_value(mocker: MockerFixture):
74+
attr = AttrR(Int(), initial_value=0)
75+
76+
async def update(attr: AttrR):
77+
await asyncio.sleep(0.5)
78+
await attr.update(1)
79+
80+
asyncio.create_task(update(attr))
81+
82+
wait_mock = mocker.spy(asyncio, "wait_for")
83+
with pytest.raises(TimeoutError):
84+
await attr.wait_for_value(10, timeout=0.2)
85+
86+
await attr.wait_for_value(1, timeout=1)
87+
88+
assert wait_mock.call_count == 2
89+
90+
# Returns immediately without creating event if value already as expected
91+
await attr.wait_for_value(1, timeout=1)
92+
assert wait_mock.call_count == 2
93+
94+
4295
@pytest.mark.asyncio
4396
async def test_attributes():
4497
device = {"state": "Idle", "number": 1, "count": False}

0 commit comments

Comments
 (0)