Skip to content

Commit a57233a

Browse files
committed
Add AttrR wait_for_value
1 parent 5658972 commit a57233a

File tree

4 files changed

+169
-1
lines changed

4 files changed

+169
-1
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
("py:class", "asyncio.events.AbstractEventLoop"),
7878
("py:class", "asyncio.streams.StreamReader"),
7979
("py:class", "asyncio.streams.StreamWriter"),
80+
("py:class", "asyncio.locks.Event"),
8081
# Annoying error:
8182
# docstring of collections.abc.Callable:1: WARNING:
8283
# 'any' reference target not found: self [ref.any]

src/fastcs/attributes/attr_r.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from fastcs.attributes.attribute import Attribute
88
from fastcs.attributes.attribute_io_ref import AttributeIORefT
9+
from fastcs.attributes.util import AttrValuePredicate, PredicateEvent
910
from fastcs.datatypes import DataType, DType_T
1011
from fastcs.logging import bind_logger
1112

@@ -39,6 +40,8 @@ def __init__(
3940
"""Callback to update the value of the attribute with an IO to the source"""
4041
self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None
4142
"""Callbacks to publish changes to the value of the attribute"""
43+
self._on_update_events: set[PredicateEvent[DType_T]] = set()
44+
"""Events to set when the value satisifies some predicate"""
4245

4346
def get(self) -> DType_T:
4447
"""Get the cached value of the attribute."""
@@ -67,6 +70,10 @@ async def update(self, value: Any) -> None:
6770

6871
self._value = self._datatype.validate(value)
6972

73+
self._on_update_events -= {
74+
e for e in self._on_update_events if e.set(self._value)
75+
}
76+
7077
if self._on_update_callbacks is not None:
7178
try:
7279
await asyncio.gather(
@@ -78,6 +85,16 @@ async def update(self, value: Any) -> None:
7885
)
7986
raise
8087

88+
def _register_update_event(self, event: PredicateEvent):
89+
"""Register an event to be set when the value satisfies a predicate
90+
91+
Args:
92+
predicate: The predicate to filter event set by
93+
event: The event to set
94+
95+
"""
96+
self._on_update_events.add(event)
97+
8198
def add_on_update_callback(self, callback: AttrOnUpdateCallback[DType_T]) -> None:
8299
"""Add a callback to be called when the value of the attribute is updated
83100
@@ -115,3 +132,62 @@ async def update_attribute():
115132
raise
116133

117134
return update_attribute
135+
136+
async def wait_for_predicate(
137+
self, predicate: AttrValuePredicate[DType_T], *, timeout: float
138+
):
139+
"""Wait for the predicate to be satisfied when called with the current value
140+
141+
Args:
142+
predicate: The predicate to test - a callable that takes the attribute
143+
value and returns True if the event should be set
144+
timeout: The timeout in seconds
145+
146+
"""
147+
if predicate(self._value):
148+
self.log_event(
149+
"Predicate already satisfied", predicate=predicate, attribute=self
150+
)
151+
return
152+
153+
self._register_update_event(update_event := PredicateEvent(predicate))
154+
155+
self.log_event("Waiting for predicate", predicate=predicate, attribute=self)
156+
try:
157+
await asyncio.wait_for(update_event.wait(), timeout)
158+
except TimeoutError:
159+
self._on_update_events.remove(update_event)
160+
raise TimeoutError(
161+
f"Timeout waiting for predicate {predicate}. "
162+
f"Current value: {self._value}"
163+
) from None
164+
165+
self.log_event("Predicate satisfied", predicate=predicate, attribute=self)
166+
167+
async def wait_for_value(self, value: DType_T, *, timeout: float):
168+
"""Wait for value to equal the required value
169+
170+
Args:
171+
value: The value to wait for
172+
timeout: The timeout in seconds
173+
174+
Raises:
175+
TimeoutError: If the attribute does not reach the required value within the
176+
timeout
177+
178+
"""
179+
if self._value == value:
180+
self.log_event("Value already equal", value=value, attribute=self)
181+
return
182+
183+
def predicate(v: DType_T) -> bool:
184+
return v == value
185+
186+
try:
187+
await self.wait_for_predicate(predicate, timeout=timeout)
188+
except TimeoutError:
189+
raise TimeoutError(
190+
f"Timeout waiting for value {value}. Current value: {self._value}"
191+
) from None
192+
193+
self.log_event("Value equal", value=value, attribute=self)

src/fastcs/attributes/util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
from dataclasses import dataclass, field
4+
from typing import Generic
5+
6+
from fastcs.datatypes import DType_T
7+
8+
AttrValuePredicate = Callable[[DType_T], bool]
9+
10+
11+
@dataclass(eq=False)
12+
class PredicateEvent(Generic[DType_T]):
13+
"""A wrapper of an asyncio.Event that only triggers when a predicate is satisfied"""
14+
15+
_predicate: AttrValuePredicate[DType_T]
16+
"""Predicate to filter set calls by"""
17+
_event: asyncio.Event = field(default_factory=asyncio.Event)
18+
"""Event to set"""
19+
20+
def set(self, value: DType_T):
21+
"""Set the event if the predicate is satisfied by the value
22+
23+
Returns:
24+
`True` if the event was set, else `False`
25+
26+
"""
27+
if self._predicate(value):
28+
self._event.set()
29+
return True
30+
31+
return False
32+
33+
async def wait(self):
34+
"""Wait for the event to be set"""
35+
await self._event.wait()
36+
37+
def __hash__(self) -> int:
38+
"""Make instances unique when stored in sets"""
39+
return id(self)

tests/test_attributes.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass
23
from functools import partial
34
from typing import Generic, TypeVar
@@ -12,7 +13,7 @@
1213
NumberT = TypeVar("NumberT", int, float)
1314

1415

15-
def test_attribute():
16+
def test_attr_r():
1617
attr = AttrR(String(), group="test group")
1718

1819
with pytest.raises(RuntimeError):
@@ -39,6 +40,57 @@ def test_attribute():
3940
assert attr.get() == ""
4041

4142

43+
@pytest.mark.asyncio
44+
async def test_wait_for_predicate(mocker: MockerFixture):
45+
attr = AttrR(Int(), initial_value=0)
46+
47+
async def update(attr: AttrR):
48+
while True:
49+
await asyncio.sleep(0.1)
50+
await attr.update(attr.get() + 3) # 3, 6, 9, 12 != 10
51+
52+
asyncio.create_task(update(attr))
53+
54+
# We won't see exactly 10 so check for greater than
55+
def predicate(v: int) -> bool:
56+
return v > 10
57+
58+
register_spy = mocker.spy(attr, "_register_update_event")
59+
with pytest.raises(TimeoutError):
60+
await attr.wait_for_predicate(predicate, timeout=0.2)
61+
62+
await attr.wait_for_predicate(predicate, timeout=1)
63+
64+
assert register_spy.call_count == 2
65+
66+
# Returns immediately without creating event if value already as expected
67+
await attr.wait_for_predicate(predicate, timeout=1)
68+
assert register_spy.call_count == 2
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_wait_for_value(mocker: MockerFixture):
73+
attr = AttrR(Int(), initial_value=0)
74+
75+
async def update(attr: AttrR):
76+
await asyncio.sleep(0.5)
77+
await attr.update(1)
78+
79+
asyncio.create_task(update(attr))
80+
81+
register_spy = mocker.spy(attr, "_register_update_event")
82+
with pytest.raises(TimeoutError):
83+
await attr.wait_for_value(1, timeout=0.2)
84+
85+
await attr.wait_for_value(1, timeout=1)
86+
87+
assert register_spy.call_count == 2
88+
89+
# Returns immediately without creating event if value already as expected
90+
await attr.wait_for_value(1, timeout=1)
91+
assert register_spy.call_count == 2
92+
93+
4294
@pytest.mark.asyncio
4395
async def test_attributes():
4496
device = {"state": "Idle", "number": 1, "count": False}

0 commit comments

Comments
 (0)