66from __future__ import annotations
77
88import logging
9- import math
10- from collections .abc import AsyncIterator
9+ from collections .abc import AsyncIterator , Callable
1110from dataclasses import dataclass
12- from typing import Generic
1311
14- from typing_extensions import override
12+ from frequenz .quantities import Quantity
13+ from typing_extensions import TypeIs , override
1514
1615from ..._internal ._math import is_close_to_zero
1716from .._base_types import QuantityT , Sample
2221
2322
2423@dataclass (kw_only = True )
25- class TelemetryStream (AstNode , Generic [QuantityT ]):
24+ class TelemetryStream (AstNode [QuantityT ]):
2625 """A AST node that retrieves values from a component's telemetry stream."""
2726
2827 source : str
29- stream : AsyncIterator [Sample [QuantityT ]]
28+ stream : AsyncIterator [Sample [QuantityT ] | Sample [Quantity ]]
29+ create_method : Callable [[float ], QuantityT ]
3030 _latest_sample : Sample [QuantityT ] | None = None
3131
3232 @property
@@ -35,13 +35,11 @@ def latest_sample(self) -> Sample[QuantityT] | None:
3535 return self ._latest_sample
3636
3737 @override
38- def evaluate (self ) -> float | None :
38+ def evaluate (self ) -> Sample [ QuantityT ] | None :
3939 """Return the base value of the latest sample for this component."""
4040 if self ._latest_sample is None :
4141 raise ValueError ("Next value has not been fetched yet." )
42- if self ._latest_sample .value is None :
43- return None
44- return self ._latest_sample .value .base_value
42+ return self ._latest_sample
4543
4644 @override
4745 def format (self , wrap : bool = False ) -> str :
@@ -50,17 +48,30 @@ def format(self, wrap: bool = False) -> str:
5048
5149 async def fetch_next (self ) -> None :
5250 """Fetch the next value for this component and store it internally."""
53- self ._latest_sample = await anext (self .stream )
51+ latest_sample = await anext (self .stream )
52+ if self ._is_quantity_sample (latest_sample ):
53+ assert latest_sample .value is not None
54+ self ._latest_sample = Sample (
55+ timestamp = latest_sample .timestamp ,
56+ value = self .create_method (latest_sample .value .base_value ),
57+ )
58+ else :
59+ self ._latest_sample = latest_sample
60+
61+ def _is_quantity_sample (
62+ self , sample : Sample [QuantityT ] | Sample [Quantity ]
63+ ) -> TypeIs [Sample [Quantity ]]:
64+ return isinstance (sample .value , Quantity )
5465
5566
5667@dataclass (kw_only = True )
57- class FunCall (AstNode ):
68+ class FunCall (AstNode [ QuantityT ] ):
5869 """A function call in the formula."""
5970
60- function : Function
71+ function : Function [ QuantityT ]
6172
6273 @override
63- def evaluate (self ) -> float | None :
74+ def evaluate (self ) -> Sample [ QuantityT ] | QuantityT | None :
6475 """Evaluate the function call with its arguments."""
6576 return self .function ()
6677
@@ -71,37 +82,67 @@ def format(self, wrap: bool = False) -> str:
7182
7283
7384@dataclass (kw_only = True )
74- class Constant (AstNode ):
85+ class Constant (AstNode [ QuantityT ] ):
7586 """A constant numerical value in the formula."""
7687
77- value : float
88+ value : QuantityT
7889
7990 @override
80- def evaluate (self ) -> float | None :
91+ def evaluate (self ) -> QuantityT | None :
8192 """Return the constant value."""
8293 return self .value
8394
8495 @override
8596 def format (self , wrap : bool = False ) -> str :
8697 """Return a string representation of the constant node."""
87- return str (self .value )
98+ return str (self .value . base_value )
8899
89100
90101@dataclass (kw_only = True )
91- class Add (AstNode ):
102+ class Add (AstNode [ QuantityT ] ):
92103 """Addition operation node."""
93104
94- left : AstNode
95- right : AstNode
105+ left : AstNode [ QuantityT ]
106+ right : AstNode [ QuantityT ]
96107
97108 @override
98- def evaluate (self ) -> float | None :
109+ def evaluate (self ) -> Sample [ QuantityT ] | QuantityT | None :
99110 """Evaluate the addition of the left and right nodes."""
100111 left = self .left .evaluate ()
101112 right = self .right .evaluate ()
102- if left is None or right is None :
103- return None
104- return left + right
113+ match left , right :
114+ case Sample (), Sample ():
115+ if left .value is None :
116+ return left
117+ if right .value is None :
118+ return right
119+ return Sample (
120+ timestamp = left .timestamp ,
121+ value = left .value + right .value ,
122+ )
123+ case Quantity (), Quantity ():
124+ return left + right
125+ case (Sample (), Quantity ()):
126+ return (
127+ left
128+ if left .value is None
129+ else Sample (
130+ timestamp = left .timestamp ,
131+ value = left .value + right ,
132+ )
133+ )
134+ case (Quantity (), Sample ()):
135+ return (
136+ right
137+ if right .value is None
138+ else Sample (
139+ timestamp = right .timestamp ,
140+ value = left + right .value ,
141+ )
142+ )
143+ case (None , _) | (_, None ):
144+ return None
145+ return None
105146
106147 @override
107148 def format (self , wrap : bool = False ) -> str :
@@ -113,20 +154,51 @@ def format(self, wrap: bool = False) -> str:
113154
114155
115156@dataclass (kw_only = True )
116- class Sub (AstNode ):
157+ class Sub (AstNode [ QuantityT ] ):
117158 """Subtraction operation node."""
118159
119- left : AstNode
120- right : AstNode
160+ left : AstNode [ QuantityT ]
161+ right : AstNode [ QuantityT ]
121162
122163 @override
123- def evaluate (self ) -> float | None :
164+ def evaluate (self ) -> Sample [ QuantityT ] | QuantityT | None :
124165 """Evaluate the subtraction of the right node from the left node."""
125166 left = self .left .evaluate ()
126167 right = self .right .evaluate ()
127- if left is None or right is None :
128- return None
129- return left - right
168+ print ("Sub.evaluate:" , left , right )
169+ match left , right :
170+ case Sample (), Sample ():
171+ if left .value is None :
172+ return left
173+ if right .value is None :
174+ return right
175+ return Sample (
176+ timestamp = left .timestamp ,
177+ value = left .value - right .value ,
178+ )
179+ case Quantity (), Quantity ():
180+ return left - right
181+ case (Sample (), Quantity ()):
182+ return (
183+ left
184+ if left .value is None
185+ else Sample (
186+ timestamp = left .timestamp ,
187+ value = left .value - right ,
188+ )
189+ )
190+ case (Quantity (), Sample ()):
191+ return (
192+ right
193+ if right .value is None
194+ else Sample (
195+ timestamp = right .timestamp ,
196+ value = left - right .value ,
197+ )
198+ )
199+ case (None , _) | (_, None ):
200+ return None
201+ return None
130202
131203 @override
132204 def format (self , wrap : bool = False ) -> str :
@@ -138,20 +210,52 @@ def format(self, wrap: bool = False) -> str:
138210
139211
140212@dataclass (kw_only = True )
141- class Mul (AstNode ):
213+ class Mul (AstNode [ QuantityT ] ):
142214 """Multiplication operation node."""
143215
144- left : AstNode
145- right : AstNode
216+ left : AstNode [ QuantityT ]
217+ right : AstNode [ QuantityT ]
146218
147219 @override
148- def evaluate (self ) -> float | None :
220+ def evaluate (self ) -> Sample [ QuantityT ] | QuantityT | None :
149221 """Evaluate the multiplication of the left and right nodes."""
150222 left = self .left .evaluate ()
151223 right = self .right .evaluate ()
152- if left is None or right is None :
153- return None
154- return left * right
224+ match left , right :
225+ case Sample (), Sample ():
226+ if left .value is None :
227+ return left
228+ if right .value is None :
229+ return right
230+ return Sample (
231+ timestamp = left .timestamp ,
232+ value = left .value * right .value .base_value ,
233+ )
234+ case Quantity (), Quantity ():
235+ return left .__class__ ._new ( # pylint: disable=protected-access
236+ left .base_value * right .base_value
237+ )
238+ case (Sample (), Quantity ()):
239+ return (
240+ left
241+ if left .value is None
242+ else Sample (
243+ timestamp = left .timestamp ,
244+ value = left .value * right .base_value ,
245+ )
246+ )
247+ case (Quantity (), Sample ()):
248+ return (
249+ right
250+ if right .value is None
251+ else Sample (
252+ timestamp = right .timestamp ,
253+ value = right .value * left .base_value ,
254+ )
255+ )
256+ case (None , _) | (_, None ):
257+ return None
258+ return None
155259
156260 @override
157261 def format (self , wrap : bool = False ) -> str :
@@ -160,22 +264,47 @@ def format(self, wrap: bool = False) -> str:
160264
161265
162266@dataclass (kw_only = True )
163- class Div (AstNode ):
267+ class Div (AstNode [ QuantityT ] ):
164268 """Division operation node."""
165269
166- left : AstNode
167- right : AstNode
270+ left : AstNode [ QuantityT ]
271+ right : AstNode [ QuantityT ]
168272
169273 @override
170- def evaluate (self ) -> float | None :
274+ def evaluate (self ) -> QuantityT | None :
171275 """Evaluate the division of the left node by the right node."""
172276 left = self .left .evaluate ()
173277 right = self .right .evaluate ()
174- if left is None or right is None :
175- return None
176- if is_close_to_zero (right ):
177- return math .nan
178- return left / right
278+ match left , right :
279+ case Sample (), Sample ():
280+ if left .value is None :
281+ return None
282+ if right .value is None :
283+ return None
284+ if is_close_to_zero (right .value .base_value ):
285+ _logger .warning ("Division by zero encountered in formula." )
286+ return None
287+ return left .value / right .value .base_value
288+ case Quantity (), Quantity ():
289+ if is_close_to_zero (right .base_value ):
290+ _logger .warning ("Division by zero encountered in formula." )
291+ return None
292+ return left / right .base_value
293+ case (Sample (), Quantity ()):
294+ if is_close_to_zero (right .base_value ):
295+ _logger .warning ("Division by zero encountered in formula." )
296+ return None
297+ return None if left .value is None else left .value / right .base_value
298+ case (Quantity (), Sample ()):
299+ if right .value is None :
300+ return None
301+ if is_close_to_zero (right .value .base_value ):
302+ _logger .warning ("Division by zero encountered in formula." )
303+ return None
304+ return left / right .value .base_value
305+ case (None , _) | (_, None ):
306+ return None
307+ return None
179308
180309 @override
181310 def format (self , wrap : bool = False ) -> str :
0 commit comments