Skip to content

Commit a06702b

Browse files
committed
Implement lazy subscription for Coalesce function
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 631e5d2 commit a06702b

File tree

4 files changed

+111
-4
lines changed

4 files changed

+111
-4
lines changed

src/frequenz/sdk/timeseries/formulas/_functions.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import abc
99
import asyncio
10+
import logging
1011
from dataclasses import dataclass, field
1112
from datetime import datetime
1213
from typing import Generic
@@ -15,8 +16,11 @@
1516
from typing_extensions import override
1617

1718
from .._base_types import QuantityT, Sample
19+
from ._ast import Constant
1820
from ._base_ast_node import AstNode, NodeSynchronizer
1921

22+
_logger = logging.getLogger(__name__)
23+
2024

2125
@dataclass(kw_only=True)
2226
class FunCall(AstNode[QuantityT]):
@@ -99,6 +103,8 @@ def from_string(
99103
class Coalesce(Function[QuantityT]):
100104
"""A function that returns the first non-None argument."""
101105

106+
num_subscribed: int = 0
107+
102108
@property
103109
@override
104110
def name(self) -> str:
@@ -110,22 +116,70 @@ async def __call__(self) -> Sample[QuantityT] | QuantityT | None:
110116
"""Return the first non-None argument."""
111117
ts: datetime | None = None
112118

113-
args = await self._synchronizer.evaluate(self.params)
114-
for arg in args:
119+
if self.num_subscribed == 0:
120+
await self._subscribe_next()
121+
122+
args = await self._synchronizer.evaluate(
123+
self.params[: self.num_subscribed], sync_to_first_node=True
124+
)
125+
for ctr, arg in enumerate(args, start=1):
115126
match arg:
116127
case Sample(timestamp, value):
117128
if value is not None:
129+
# Found a non-None value, unsubscribe from subsequent params
130+
if ctr < self.num_subscribed:
131+
await self._unsubscribe_after(ctr)
118132
return arg
119133
ts = timestamp
120134
case Quantity():
135+
# Found a non-None value, unsubscribe from subsequent params
136+
if ctr < self.num_subscribed:
137+
await self._unsubscribe_after(ctr)
121138
if ts is not None:
122139
return Sample(timestamp=ts, value=arg)
123140
return arg
124141
case None:
125142
continue
143+
# Don't have a non-None value yet, subscribe to the next parameter for
144+
# next time and return None for now, unless the next value is a constant.
145+
next_value: Sample[QuantityT] | QuantityT | None = None
146+
await self._subscribe_next()
147+
148+
if isinstance(self.params[self.num_subscribed - 1], Constant):
149+
next_value = await self.params[self.num_subscribed - 1].evaluate()
150+
if isinstance(next_value, Sample):
151+
return next_value
152+
126153
if ts is not None:
127-
return Sample(timestamp=ts, value=None)
128-
return None
154+
return Sample(timestamp=ts, value=next_value)
155+
return next_value
156+
157+
@override
158+
async def subscribe(self) -> None:
159+
"""Subscribe to the first parameter if not already subscribed."""
160+
if self.num_subscribed == 0:
161+
await self._subscribe_next()
162+
163+
async def _subscribe_next(self) -> None:
164+
"""Subscribe to the next parameter."""
165+
if self.num_subscribed < len(self.params):
166+
_logger.debug(
167+
"Coalesce subscribing to param %d: %s",
168+
self.num_subscribed + 1,
169+
self.params[self.num_subscribed],
170+
)
171+
await self.params[self.num_subscribed].subscribe()
172+
self.num_subscribed += 1
173+
174+
async def _unsubscribe_after(self, index: int) -> None:
175+
"""Unsubscribe from parameters after the given index."""
176+
for param in self.params[index:]:
177+
_logger.debug(
178+
"Coalesce unsubscribing from param: %s",
179+
param,
180+
)
181+
await param.unsubscribe()
182+
self.num_subscribed = index
129183

130184

131185
class Max(Function[QuantityT]):

tests/timeseries/_formulas/test_formula_composition.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,29 @@ async def test_formula_composition_min_max(self, mocker: MockerFixture) -> None:
230230
stack.push_async_callback(formula_max.stop)
231231
formula_max_rx = formula_max.new_receiver()
232232

233+
assert (
234+
str(formula_min)
235+
== "[grid_power_min]("
236+
+ "MIN([grid_power](COALESCE(#4, #7)), [chp_power](COALESCE(#5, #7, 0.0)))"
237+
+ ")"
238+
)
239+
assert (
240+
str(formula_max)
241+
== "[grid_power_max]("
242+
+ "MAX([grid_power](COALESCE(#4, #7)), [chp_power](COALESCE(#5, #7, 0.0)))"
243+
+ ")"
244+
)
245+
246+
await mockgrid.mock_resampler.send_meter_power([100.0, 200.0])
247+
await mockgrid.mock_resampler.send_chp_power([None])
248+
# Because it got None for CHP (#5), it will then subscribe to the meter (#7).
249+
# So we have to send again so it uses the meter value.
233250
await mockgrid.mock_resampler.send_meter_power([100.0, 200.0])
234251
await mockgrid.mock_resampler.send_chp_power([None])
235252

236253
# Test min
237254
min_pow = await formula_min_rx.receive()
255+
min_pow = await formula_min_rx.receive()
238256
assert (
239257
min_pow
240258
and min_pow.value
@@ -243,17 +261,21 @@ async def test_formula_composition_min_max(self, mocker: MockerFixture) -> None:
243261

244262
# Test max
245263
max_pow = await formula_max_rx.receive()
264+
max_pow = await formula_max_rx.receive()
246265
assert (
247266
max_pow
248267
and max_pow.value
249268
and max_pow.value.isclose(Power.from_watts(200.0))
250269
)
251270

271+
await mockgrid.mock_resampler.send_meter_power([-100.0, -200.0])
272+
await mockgrid.mock_resampler.send_chp_power([None])
252273
await mockgrid.mock_resampler.send_meter_power([-100.0, -200.0])
253274
await mockgrid.mock_resampler.send_chp_power([None])
254275

255276
# Test min
256277
min_pow = await formula_min_rx.receive()
278+
min_pow = await formula_min_rx.receive()
257279
assert (
258280
min_pow
259281
and min_pow.value
@@ -262,6 +284,7 @@ async def test_formula_composition_min_max(self, mocker: MockerFixture) -> None:
262284

263285
# Test max
264286
max_pow = await formula_max_rx.receive()
287+
max_pow = await formula_max_rx.receive()
265288
assert (
266289
max_pow
267290
and max_pow.value

tests/timeseries/_formulas/test_formulas.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""Tests for the Formula implementation."""
55

66
import asyncio
7+
import logging
78
from collections import OrderedDict
89
from collections.abc import Callable
910
from datetime import datetime, timedelta
@@ -22,6 +23,8 @@
2223
ResampledStreamFetcher,
2324
)
2425

26+
_logger = logging.getLogger(__name__)
27+
2528

2629
@pytest.fixture
2730
def event_loop_policy() -> async_solipsism.EventLoopPolicy:
@@ -40,6 +43,7 @@ async def run_test( # pylint: disable=too-many-locals
4043
io_pairs: list[tuple[list[float | None], float | None]],
4144
) -> None:
4245
"""Run a formula test."""
46+
_logger.debug("TESTING FORMULA: %s", formula_str)
4347
channels: OrderedDict[ComponentId, Broadcast[Sample[Quantity]]] = OrderedDict()
4448
for comp_id in component_ids:
4549
channels[ComponentId(comp_id)] = Broadcast(
@@ -67,6 +71,7 @@ def stream_recv(comp_id: ComponentId) -> Receiver[Sample[Quantity]]:
6771

6872
for io_pair in io_pairs:
6973
io_input, io_output = io_pair
74+
now += timedelta(seconds=1)
7075
_ = await asyncio.gather(
7176
*[
7277
chan.new_sender().send(
@@ -90,6 +95,7 @@ def stream_recv(comp_id: ComponentId) -> Receiver[Sample[Quantity]]:
9095
and next_val.value.base_value == io_output
9196
)
9297
tests_passed += 1
98+
_logger.debug("%s: Passed inputs: %s", tests_passed, io_input)
9399
assert tests_passed == len(io_pairs)
94100

95101
async def test_simple(self) -> None:
@@ -287,6 +293,17 @@ async def test_max_min_coalesce(self) -> None:
287293
([10.0, 12.0, 15.0], 25.0),
288294
],
289295
)
296+
await self.run_test(
297+
"#2 + COALESCE(#4, #5, 0.0)",
298+
"[f](#2 + COALESCE(#4, #5, 0.0))",
299+
[2, 4, 5],
300+
[
301+
([10.0, 12.0, 15.0], 22.0),
302+
([10.0, None, 15.0], None),
303+
([10.0, None, 15.0], 25.0),
304+
([10.0, None, None], 10.0),
305+
],
306+
)
290307
await self.run_test(
291308
"MIN(#2, #4) + COALESCE(#5, 0.0)",
292309
"[f](MIN(#2, #4) + COALESCE(#5, 0.0))",
@@ -370,6 +387,8 @@ def stream_recv(comp_id: int) -> Receiver[Sample[Quantity]]:
370387

371388
assert str(formula) == expected
372389

390+
_logger.debug("TESTING FORMULA: %s", expected)
391+
373392
result_chan = formula.new_receiver()
374393
await asyncio.sleep(0.1)
375394
now = datetime.now()
@@ -394,6 +413,7 @@ def stream_recv(comp_id: int) -> Receiver[Sample[Quantity]]:
394413
and next_val.value.base_value == io_output
395414
)
396415
tests_passed += 1
416+
_logger.debug("%s: Passed inputs: %s", tests_passed, io_input)
397417
await formula.stop()
398418
assert tests_passed == len(io_pairs)
399419

@@ -562,7 +582,9 @@ async def test_coalesce(self) -> None:
562582
lambda c2, c4, c5: c2.coalesce([c4, c5]),
563583
"[l2](COALESCE([0](#0), [1](#1), [2](#2)))",
564584
[
585+
([None, 12.0, 15.0], None),
565586
([None, 12.0, 15.0], 12.0),
587+
([None, None, 15.0], None),
566588
([None, None, 15.0], 15.0),
567589
([10.0, None, 15.0], 10.0),
568590
([None, None, None], None),
@@ -574,9 +596,14 @@ async def test_coalesce(self) -> None:
574596
lambda c2, c4, c5: (c2 * 5.0).coalesce([c4 / 2.0, c5]),
575597
"[l2](COALESCE([0](#0) * 5.0, [1](#1) / 2.0, [2](#2)))",
576598
[
599+
([None, 12.0, 15.0], None),
577600
([None, 12.0, 15.0], 6.0),
601+
([None, None, 15.0], None),
578602
([None, None, 15.0], 15.0),
579603
([10.0, None, 15.0], 50.0),
604+
([None, None, 15.0], None),
605+
([None, None, 15.0], None),
606+
([None, None, 15.0], 15.0),
580607
([None, None, None], None),
581608
],
582609
)

tests/timeseries/_formulas/test_formulas_3_phase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def stream_recv(comp_id: int) -> Receiver[Sample[Quantity]]:
8282
)
8383
builder = make_builder(p3_formula_1, p3_formula_2)
8484
formula = builder.build("l2")
85+
8586
receiver = formula.new_receiver()
8687

8788
await asyncio.sleep(0.1)
@@ -154,7 +155,9 @@ async def test_composition(self) -> None:
154155
await self.run_test(
155156
[
156157
([(4.0, 9.0, 16.0), (2.0, 13.0, 4.0)], (4.0, 9.0, 16.0)),
158+
([(-5.0, 10.0, None), (10.0, 5.0, None)], (-5.0, 10.0, None)),
157159
([(-5.0, 10.0, None), (10.0, 5.0, None)], (-5.0, 10.0, 0.0)),
160+
([(None, 2.0, 3.0), (2.0, None, 4.0)], (None, 2.0, 3.0)),
158161
([(None, 2.0, 3.0), (2.0, None, 4.0)], (2.0, 2.0, 3.0)),
159162
],
160163
lambda f1, f2: f1.coalesce(

0 commit comments

Comments
 (0)