Skip to content

Commit 2149b01

Browse files
committed
Add AttrR wait_for_value
1 parent b55948e commit 2149b01

File tree

4 files changed

+165
-1
lines changed

4 files changed

+165
-1
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
("py:class", "asyncio.events.AbstractEventLoop"),
7878
("py:class", "asyncio.streams.StreamReader"),
7979
("py:class", "asyncio.streams.StreamWriter"),
80+
("py:class", "asyncio.locks.Event"),
8081
# Annoying error:
8182
# docstring of collections.abc.Callable:1: WARNING:
8283
# 'any' reference target not found: self [ref.any]

src/fastcs/attributes/attr_r.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from fastcs.attributes.attribute import Attribute
88
from fastcs.attributes.attribute_io_ref import AttributeIORefT
9+
from fastcs.attributes.util import AttributeValuePredicate, PredicateEvent
910
from fastcs.datatypes import DataType, DType_T
1011
from fastcs.logging import bind_logger
1112

@@ -39,6 +40,8 @@ def __init__(
3940
"""Callback to update the value of the attribute with an IO to the source"""
4041
self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None
4142
"""Callbacks to publish changes to the value of the attribute"""
43+
self._on_update_events: set[PredicateEvent[DType_T]] = set()
44+
"""Events to set when the value satisifies some predicate"""
4245

4346
def get(self) -> DType_T:
4447
"""Get the cached value of the attribute."""
@@ -67,6 +70,10 @@ async def update(self, value: Any) -> None:
6770

6871
self._value = self._datatype.validate(value)
6972

73+
self._on_update_events -= {
74+
e for e in self._on_update_events if e.set(self._value)
75+
}
76+
7077
if self._on_update_callbacks is not None:
7178
try:
7279
await asyncio.gather(
@@ -78,6 +85,19 @@ async def update(self, value: Any) -> None:
7885
)
7986
raise
8087

88+
def _register_update_event(
89+
self, predicate: AttributeValuePredicate[DType_T], event: asyncio.Event
90+
):
91+
"""Register an event to be set when the value satisfies a predicate
92+
93+
Args:
94+
predicate: The predicate to check - a callable that takes the attribute
95+
value and returns True if the event should be set
96+
event: The event to set
97+
98+
"""
99+
self._on_update_events.add(PredicateEvent(predicate, event))
100+
81101
def add_on_update_callback(self, callback: AttrOnUpdateCallback[DType_T]) -> None:
82102
"""Add a callback to be called when the value of the attribute is updated
83103
@@ -115,3 +135,61 @@ async def update_attribute():
115135
raise
116136

117137
return update_attribute
138+
139+
async def wait_for_predicate(
140+
self, predicate: AttributeValuePredicate[DType_T], *, timeout: float
141+
):
142+
"""Wait for the predicate to return True when called with the current value
143+
144+
Args:
145+
predicate: The predicate to check - a callable that takes the attribute
146+
value and returns True if the event should be set
147+
timeout: The timeout in seconds
148+
149+
"""
150+
if predicate(self._value):
151+
self.log_event(
152+
"Predicate already satisfied", predicate=predicate, attribute=self
153+
)
154+
return
155+
156+
self._register_update_event(predicate, update_event := asyncio.Event())
157+
158+
self.log_event("Waiting for predicate", predicate=predicate, attribute=self)
159+
try:
160+
await asyncio.wait_for(update_event.wait(), timeout)
161+
except TimeoutError:
162+
raise TimeoutError(
163+
f"Timeout waiting for predicate {predicate}. "
164+
f"Current value: {self._value}"
165+
) from None
166+
167+
self.log_event("Predicate satisfied", predicate=predicate, attribute=self)
168+
169+
async def wait_for_value(self, value: DType_T, *, timeout: float):
170+
"""Wait for value to change to the required value
171+
172+
Args:
173+
value: The value to wait for
174+
timeout: The timeout in seconds
175+
176+
Raises:
177+
TimeoutError: If the attribute does not reach the required value within the
178+
timeout
179+
180+
"""
181+
if self._value == value:
182+
self.log_event("Value already equal", value=value, attribute=self)
183+
return
184+
185+
def predicate(v: DType_T) -> bool:
186+
return v == value
187+
188+
try:
189+
await self.wait_for_predicate(predicate, timeout=timeout)
190+
except TimeoutError:
191+
raise TimeoutError(
192+
f"Timeout waiting for value {value}. Current value: {self._value}"
193+
) from None
194+
195+
self.log_event("Value equal", value=value, attribute=self)

src/fastcs/attributes/util.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
from dataclasses import dataclass
4+
from typing import Generic
5+
6+
from fastcs.datatypes import DType_T
7+
8+
AttributeValuePredicate = Callable[[DType_T], bool]
9+
10+
11+
@dataclass(eq=False)
12+
class PredicateEvent(Generic[DType_T]):
13+
"""A wrapper of an asyncio.Event that only triggers when a predicate is true"""
14+
15+
_predicate: AttributeValuePredicate[DType_T]
16+
"""Predicate to filter set calls by"""
17+
_event: asyncio.Event
18+
"""Event to set when the predicate returns True"""
19+
20+
def set(self, value: DType_T):
21+
"""Set the event if the predicate returns True
22+
23+
Returns:
24+
True if the predicate was True and the event was set, False otherwise
25+
26+
"""
27+
if self._predicate(value):
28+
self._event.set()
29+
return True
30+
31+
return False
32+
33+
def __hash__(self) -> int:
34+
return id(self)

tests/test_attributes.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass
23
from functools import partial
34
from typing import Generic, TypeVar
@@ -12,7 +13,7 @@
1213
NumberT = TypeVar("NumberT", int, float)
1314

1415

15-
def test_attribute():
16+
def test_attr_r():
1617
attr = AttrR(String(), group="test group")
1718

1819
with pytest.raises(RuntimeError):
@@ -39,6 +40,56 @@ def test_attribute():
3940
assert attr.get() == ""
4041

4142

43+
@pytest.mark.asyncio
44+
async def test_wait_for_predicate(mocker: MockerFixture):
45+
attr = AttrR(Int(), initial_value=0)
46+
47+
async def update(attr: AttrR):
48+
while True:
49+
await asyncio.sleep(0.1)
50+
await attr.update(attr.get() + 3)
51+
52+
asyncio.create_task(update(attr))
53+
54+
def predicate(v: int) -> bool:
55+
return v > 10
56+
57+
register_spy = mocker.spy(attr, "_register_update_event")
58+
with pytest.raises(TimeoutError):
59+
await attr.wait_for_predicate(predicate, timeout=0.2)
60+
61+
await attr.wait_for_predicate(predicate, timeout=1)
62+
63+
assert register_spy.call_count == 2
64+
65+
# Returns immediately without creating event if value already as expected
66+
await attr.wait_for_predicate(predicate, timeout=1)
67+
assert register_spy.call_count == 2
68+
69+
70+
@pytest.mark.asyncio
71+
async def test_wait_for_value(mocker: MockerFixture):
72+
attr = AttrR(Int(), initial_value=0)
73+
74+
async def update(attr: AttrR):
75+
await asyncio.sleep(0.5)
76+
await attr.update(1)
77+
78+
asyncio.create_task(update(attr))
79+
80+
register_spy = mocker.spy(attr, "_register_update_event")
81+
with pytest.raises(TimeoutError):
82+
await attr.wait_for_value(1, timeout=0.2)
83+
84+
await attr.wait_for_value(1, timeout=1)
85+
86+
assert register_spy.call_count == 2
87+
88+
# Returns immediately without creating event if value already as expected
89+
await attr.wait_for_value(1, timeout=1)
90+
assert register_spy.call_count == 2
91+
92+
4293
@pytest.mark.asyncio
4394
async def test_attributes():
4495
device = {"state": "Idle", "number": 1, "count": False}

0 commit comments

Comments
 (0)