Skip to content

Commit 97510c7

Browse files
committed
Propagate Sample types through AST evaluation
This changes `AstNode.evaluate()` to return `Sample[QuantityT] | QuantityT | None` instead of `float | None`. This makes timestamps available to coalesce node, so it knows how to synchronize newly started fallback streams with the primary streams. This also requires a `create_method` to be passed to TelemetryStream, for accurate typing, so the `Quantity` types from the resampler don't get sent out as they are, in case of simple formulas, because there is no top-level re-wrapping of the values in the formula evaluator anymore. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 5e320a9 commit 97510c7

File tree

8 files changed

+292
-114
lines changed

8 files changed

+292
-114
lines changed

src/frequenz/sdk/timeseries/formulas/_ast.py

Lines changed: 177 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
from __future__ import annotations
77

88
import logging
9-
import math
10-
from collections.abc import AsyncIterator
9+
from collections.abc import AsyncIterator, Callable
1110
from dataclasses import dataclass
12-
from typing import Generic
1311

14-
from typing_extensions import override
12+
from frequenz.quantities import Quantity
13+
from typing_extensions import TypeIs, override
1514

1615
from ..._internal._math import is_close_to_zero
1716
from .._base_types import QuantityT, Sample
@@ -22,11 +21,12 @@
2221

2322

2423
@dataclass(kw_only=True)
25-
class TelemetryStream(AstNode, Generic[QuantityT]):
24+
class TelemetryStream(AstNode[QuantityT]):
2625
"""A AST node that retrieves values from a component's telemetry stream."""
2726

2827
source: str
29-
stream: AsyncIterator[Sample[QuantityT]]
28+
stream: AsyncIterator[Sample[QuantityT] | Sample[Quantity]]
29+
create_method: Callable[[float], QuantityT]
3030
_latest_sample: Sample[QuantityT] | None = None
3131

3232
@property
@@ -35,13 +35,11 @@ def latest_sample(self) -> Sample[QuantityT] | None:
3535
return self._latest_sample
3636

3737
@override
38-
def evaluate(self) -> float | None:
38+
def evaluate(self) -> Sample[QuantityT] | None:
3939
"""Return the base value of the latest sample for this component."""
4040
if self._latest_sample is None:
4141
raise ValueError("Next value has not been fetched yet.")
42-
if self._latest_sample.value is None:
43-
return None
44-
return self._latest_sample.value.base_value
42+
return self._latest_sample
4543

4644
@override
4745
def format(self, wrap: bool = False) -> str:
@@ -50,17 +48,30 @@ def format(self, wrap: bool = False) -> str:
5048

5149
async def fetch_next(self) -> None:
5250
"""Fetch the next value for this component and store it internally."""
53-
self._latest_sample = await anext(self.stream)
51+
latest_sample = await anext(self.stream)
52+
if self._is_quantity_sample(latest_sample):
53+
assert latest_sample.value is not None
54+
self._latest_sample = Sample(
55+
timestamp=latest_sample.timestamp,
56+
value=self.create_method(latest_sample.value.base_value),
57+
)
58+
else:
59+
self._latest_sample = latest_sample
60+
61+
def _is_quantity_sample(
62+
self, sample: Sample[QuantityT] | Sample[Quantity]
63+
) -> TypeIs[Sample[Quantity]]:
64+
return isinstance(sample.value, Quantity)
5465

5566

5667
@dataclass(kw_only=True)
57-
class FunCall(AstNode):
68+
class FunCall(AstNode[QuantityT]):
5869
"""A function call in the formula."""
5970

60-
function: Function
71+
function: Function[QuantityT]
6172

6273
@override
63-
def evaluate(self) -> float | None:
74+
def evaluate(self) -> Sample[QuantityT] | QuantityT | None:
6475
"""Evaluate the function call with its arguments."""
6576
return self.function()
6677

@@ -71,37 +82,67 @@ def format(self, wrap: bool = False) -> str:
7182

7283

7384
@dataclass(kw_only=True)
74-
class Constant(AstNode):
85+
class Constant(AstNode[QuantityT]):
7586
"""A constant numerical value in the formula."""
7687

77-
value: float
88+
value: QuantityT
7889

7990
@override
80-
def evaluate(self) -> float | None:
91+
def evaluate(self) -> QuantityT | None:
8192
"""Return the constant value."""
8293
return self.value
8394

8495
@override
8596
def format(self, wrap: bool = False) -> str:
8697
"""Return a string representation of the constant node."""
87-
return str(self.value)
98+
return str(self.value.base_value)
8899

89100

90101
@dataclass(kw_only=True)
91-
class Add(AstNode):
102+
class Add(AstNode[QuantityT]):
92103
"""Addition operation node."""
93104

94-
left: AstNode
95-
right: AstNode
105+
left: AstNode[QuantityT]
106+
right: AstNode[QuantityT]
96107

97108
@override
98-
def evaluate(self) -> float | None:
109+
def evaluate(self) -> Sample[QuantityT] | QuantityT | None:
99110
"""Evaluate the addition of the left and right nodes."""
100111
left = self.left.evaluate()
101112
right = self.right.evaluate()
102-
if left is None or right is None:
103-
return None
104-
return left + right
113+
match left, right:
114+
case Sample(), Sample():
115+
if left.value is None:
116+
return left
117+
if right.value is None:
118+
return right
119+
return Sample(
120+
timestamp=left.timestamp,
121+
value=left.value + right.value,
122+
)
123+
case Quantity(), Quantity():
124+
return left + right
125+
case (Sample(), Quantity()):
126+
return (
127+
left
128+
if left.value is None
129+
else Sample(
130+
timestamp=left.timestamp,
131+
value=left.value + right,
132+
)
133+
)
134+
case (Quantity(), Sample()):
135+
return (
136+
right
137+
if right.value is None
138+
else Sample(
139+
timestamp=right.timestamp,
140+
value=left + right.value,
141+
)
142+
)
143+
case (None, _) | (_, None):
144+
return None
145+
return None
105146

106147
@override
107148
def format(self, wrap: bool = False) -> str:
@@ -113,20 +154,51 @@ def format(self, wrap: bool = False) -> str:
113154

114155

115156
@dataclass(kw_only=True)
116-
class Sub(AstNode):
157+
class Sub(AstNode[QuantityT]):
117158
"""Subtraction operation node."""
118159

119-
left: AstNode
120-
right: AstNode
160+
left: AstNode[QuantityT]
161+
right: AstNode[QuantityT]
121162

122163
@override
123-
def evaluate(self) -> float | None:
164+
def evaluate(self) -> Sample[QuantityT] | QuantityT | None:
124165
"""Evaluate the subtraction of the right node from the left node."""
125166
left = self.left.evaluate()
126167
right = self.right.evaluate()
127-
if left is None or right is None:
128-
return None
129-
return left - right
168+
print("Sub.evaluate:", left, right)
169+
match left, right:
170+
case Sample(), Sample():
171+
if left.value is None:
172+
return left
173+
if right.value is None:
174+
return right
175+
return Sample(
176+
timestamp=left.timestamp,
177+
value=left.value - right.value,
178+
)
179+
case Quantity(), Quantity():
180+
return left - right
181+
case (Sample(), Quantity()):
182+
return (
183+
left
184+
if left.value is None
185+
else Sample(
186+
timestamp=left.timestamp,
187+
value=left.value - right,
188+
)
189+
)
190+
case (Quantity(), Sample()):
191+
return (
192+
right
193+
if right.value is None
194+
else Sample(
195+
timestamp=right.timestamp,
196+
value=left - right.value,
197+
)
198+
)
199+
case (None, _) | (_, None):
200+
return None
201+
return None
130202

131203
@override
132204
def format(self, wrap: bool = False) -> str:
@@ -138,20 +210,52 @@ def format(self, wrap: bool = False) -> str:
138210

139211

140212
@dataclass(kw_only=True)
141-
class Mul(AstNode):
213+
class Mul(AstNode[QuantityT]):
142214
"""Multiplication operation node."""
143215

144-
left: AstNode
145-
right: AstNode
216+
left: AstNode[QuantityT]
217+
right: AstNode[QuantityT]
146218

147219
@override
148-
def evaluate(self) -> float | None:
220+
def evaluate(self) -> Sample[QuantityT] | QuantityT | None:
149221
"""Evaluate the multiplication of the left and right nodes."""
150222
left = self.left.evaluate()
151223
right = self.right.evaluate()
152-
if left is None or right is None:
153-
return None
154-
return left * right
224+
match left, right:
225+
case Sample(), Sample():
226+
if left.value is None:
227+
return left
228+
if right.value is None:
229+
return right
230+
return Sample(
231+
timestamp=left.timestamp,
232+
value=left.value * right.value.base_value,
233+
)
234+
case Quantity(), Quantity():
235+
return left.__class__._new( # pylint: disable=protected-access
236+
left.base_value * right.base_value
237+
)
238+
case (Sample(), Quantity()):
239+
return (
240+
left
241+
if left.value is None
242+
else Sample(
243+
timestamp=left.timestamp,
244+
value=left.value * right.base_value,
245+
)
246+
)
247+
case (Quantity(), Sample()):
248+
return (
249+
right
250+
if right.value is None
251+
else Sample(
252+
timestamp=right.timestamp,
253+
value=right.value * left.base_value,
254+
)
255+
)
256+
case (None, _) | (_, None):
257+
return None
258+
return None
155259

156260
@override
157261
def format(self, wrap: bool = False) -> str:
@@ -160,22 +264,47 @@ def format(self, wrap: bool = False) -> str:
160264

161265

162266
@dataclass(kw_only=True)
163-
class Div(AstNode):
267+
class Div(AstNode[QuantityT]):
164268
"""Division operation node."""
165269

166-
left: AstNode
167-
right: AstNode
270+
left: AstNode[QuantityT]
271+
right: AstNode[QuantityT]
168272

169273
@override
170-
def evaluate(self) -> float | None:
274+
def evaluate(self) -> QuantityT | None:
171275
"""Evaluate the division of the left node by the right node."""
172276
left = self.left.evaluate()
173277
right = self.right.evaluate()
174-
if left is None or right is None:
175-
return None
176-
if is_close_to_zero(right):
177-
return math.nan
178-
return left / right
278+
match left, right:
279+
case Sample(), Sample():
280+
if left.value is None:
281+
return None
282+
if right.value is None:
283+
return None
284+
if is_close_to_zero(right.value.base_value):
285+
_logger.warning("Division by zero encountered in formula.")
286+
return None
287+
return left.value / right.value.base_value
288+
case Quantity(), Quantity():
289+
if is_close_to_zero(right.base_value):
290+
_logger.warning("Division by zero encountered in formula.")
291+
return None
292+
return left / right.base_value
293+
case (Sample(), Quantity()):
294+
if is_close_to_zero(right.base_value):
295+
_logger.warning("Division by zero encountered in formula.")
296+
return None
297+
return None if left.value is None else left.value / right.base_value
298+
case (Quantity(), Sample()):
299+
if right.value is None:
300+
return None
301+
if is_close_to_zero(right.value.base_value):
302+
_logger.warning("Division by zero encountered in formula.")
303+
return None
304+
return left / right.value.base_value
305+
case (None, _) | (_, None):
306+
return None
307+
return None
179308

180309
@override
181310
def format(self, wrap: bool = False) -> str:

src/frequenz/sdk/timeseries/formulas/_base_ast_node.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55

66
import abc
77
from dataclasses import dataclass
8+
from typing import Generic
89

910
from typing_extensions import override
1011

12+
from ...timeseries import Sample
13+
from ...timeseries._base_types import QuantityT
14+
1115

1216
@dataclass(kw_only=True)
13-
class AstNode(abc.ABC):
17+
class AstNode(abc.ABC, Generic[QuantityT]):
1418
"""An abstract syntax tree node representing a formula expression."""
1519

1620
span: tuple[int, int] | None = None
1721

1822
@abc.abstractmethod
19-
def evaluate(self) -> float | None:
23+
def evaluate(self) -> Sample[QuantityT] | QuantityT | None:
2024
"""Evaluate the expression and return its numerical value."""
2125

2226
@abc.abstractmethod

0 commit comments

Comments
 (0)