Skip to content

Commit a7dc25c

Browse files
committed
Extract Node base class to separate module
Also rename it to AstNode This way, nodes can be accessed by the _function.py module, without circular imports. That would enable functions to handle their own arguments. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent c1a4668 commit a7dc25c

File tree

5 files changed

+67
-55
lines changed

5 files changed

+67
-55
lines changed

src/frequenz/sdk/timeseries/formulas/_ast.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from __future__ import annotations
77

8-
import abc
98
import logging
109
import math
1110
from collections.abc import AsyncIterator
@@ -16,33 +15,14 @@
1615

1716
from ..._internal._math import is_close_to_zero
1817
from .._base_types import QuantityT, Sample
18+
from ._base_ast_node import AstNode
1919
from ._functions import Function
2020

2121
_logger = logging.getLogger(__name__)
2222

2323

2424
@dataclass(kw_only=True)
25-
class Node(abc.ABC):
26-
"""An abstract syntax tree node representing a formula expression."""
27-
28-
span: tuple[int, int] | None = None
29-
30-
@abc.abstractmethod
31-
def evaluate(self) -> float | None:
32-
"""Evaluate the expression and return its numerical value."""
33-
34-
@abc.abstractmethod
35-
def format(self, wrap: bool = False) -> str:
36-
"""Return a string representation of the node."""
37-
38-
@override
39-
def __str__(self) -> str:
40-
"""Return the string representation of the node."""
41-
return self.format()
42-
43-
44-
@dataclass(kw_only=True)
45-
class TelemetryStream(Node, Generic[QuantityT]):
25+
class TelemetryStream(AstNode, Generic[QuantityT]):
4626
"""A AST node that retrieves values from a component's telemetry stream."""
4727

4828
source: str
@@ -74,11 +54,11 @@ async def fetch_next(self) -> None:
7454

7555

7656
@dataclass(kw_only=True)
77-
class FunCall(Node):
57+
class FunCall(AstNode):
7858
"""A function call in the formula."""
7959

8060
function: Function
81-
args: list[Node]
61+
args: list[AstNode]
8262

8363
@override
8464
def evaluate(self) -> float | None:
@@ -93,7 +73,7 @@ def format(self, wrap: bool = False) -> str:
9373

9474

9575
@dataclass(kw_only=True)
96-
class Constant(Node):
76+
class Constant(AstNode):
9777
"""A constant numerical value in the formula."""
9878

9979
value: float
@@ -110,11 +90,11 @@ def format(self, wrap: bool = False) -> str:
11090

11191

11292
@dataclass(kw_only=True)
113-
class Add(Node):
93+
class Add(AstNode):
11494
"""Addition operation node."""
11595

116-
left: Node
117-
right: Node
96+
left: AstNode
97+
right: AstNode
11898

11999
@override
120100
def evaluate(self) -> float | None:
@@ -135,11 +115,11 @@ def format(self, wrap: bool = False) -> str:
135115

136116

137117
@dataclass(kw_only=True)
138-
class Sub(Node):
118+
class Sub(AstNode):
139119
"""Subtraction operation node."""
140120

141-
left: Node
142-
right: Node
121+
left: AstNode
122+
right: AstNode
143123

144124
@override
145125
def evaluate(self) -> float | None:
@@ -160,11 +140,11 @@ def format(self, wrap: bool = False) -> str:
160140

161141

162142
@dataclass(kw_only=True)
163-
class Mul(Node):
143+
class Mul(AstNode):
164144
"""Multiplication operation node."""
165145

166-
left: Node
167-
right: Node
146+
left: AstNode
147+
right: AstNode
168148

169149
@override
170150
def evaluate(self) -> float | None:
@@ -182,11 +162,11 @@ def format(self, wrap: bool = False) -> str:
182162

183163

184164
@dataclass(kw_only=True)
185-
class Div(Node):
165+
class Div(AstNode):
186166
"""Division operation node."""
187167

188-
left: Node
189-
right: Node
168+
left: AstNode
169+
right: AstNode
190170

191171
@override
192172
def evaluate(self) -> float | None:
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# License: MIT
2+
# Copyright © 2025 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Formula AST node base class."""
5+
6+
import abc
7+
from dataclasses import dataclass
8+
9+
from typing_extensions import override
10+
11+
12+
@dataclass(kw_only=True)
13+
class AstNode(abc.ABC):
14+
"""An abstract syntax tree node representing a formula expression."""
15+
16+
span: tuple[int, int] | None = None
17+
18+
@abc.abstractmethod
19+
def evaluate(self) -> float | None:
20+
"""Evaluate the expression and return its numerical value."""
21+
22+
@abc.abstractmethod
23+
def format(self, wrap: bool = False) -> str:
24+
"""Return a string representation of the node."""
25+
26+
@override
27+
def __str__(self) -> str:
28+
"""Return the string representation of the node."""
29+
return self.format()

src/frequenz/sdk/timeseries/formulas/_formula.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .. import ReceiverFetcher, Sample
2121
from .._base_types import QuantityT
2222
from . import _ast
23+
from ._base_ast_node import AstNode
2324
from ._formula_evaluator import FormulaEvaluatingActor
2425
from ._functions import Coalesce, Max, Min
2526

@@ -33,7 +34,7 @@ def __init__( # pylint: disable=too-many-arguments
3334
self,
3435
*,
3536
name: str,
36-
root: _ast.Node,
37+
root: AstNode,
3738
create_method: Callable[[float], QuantityT],
3839
streams: list[_ast.TelemetryStream[QuantityT]],
3940
sub_formulas: list[Formula[QuantityT]] | None = None,
@@ -54,7 +55,7 @@ def __init__( # pylint: disable=too-many-arguments
5455
"""
5556
BackgroundService.__init__(self)
5657
self._name: str = name
57-
self._root: _ast.Node = root
58+
self._root: AstNode = root
5859
self._components: list[_ast.TelemetryStream[QuantityT]] = streams
5960
self._create_method: Callable[[float], QuantityT] = create_method
6061
self._sub_formulas: list[Formula[QuantityT]] = sub_formulas or []
@@ -154,7 +155,7 @@ class FormulaBuilder(Generic[QuantityT]):
154155

155156
def __init__(
156157
self,
157-
formula: Formula[QuantityT] | _ast.Node,
158+
formula: Formula[QuantityT] | AstNode,
158159
create_method: Callable[[float], QuantityT],
159160
streams: list[_ast.TelemetryStream[QuantityT]] | None = None,
160161
sub_formulas: list[Formula[QuantityT]] | None = None,
@@ -176,7 +177,7 @@ def __init__(
176177
"""Sub-formulas whose lifetimes are managed by this formula."""
177178

178179
if isinstance(formula, Formula):
179-
self.root: _ast.Node = _ast.TelemetryStream(
180+
self.root: AstNode = _ast.TelemetryStream(
180181
source=str(formula),
181182
stream=formula.new_receiver(),
182183
)
@@ -266,7 +267,7 @@ def coalesce(
266267
other: list[FormulaBuilder[QuantityT] | QuantityT | Formula[QuantityT]],
267268
) -> FormulaBuilder[QuantityT]:
268269
"""Create a coalesce operation node."""
269-
right_nodes: list[_ast.Node] = []
270+
right_nodes: list[AstNode] = []
270271
for item in other:
271272
if isinstance(item, FormulaBuilder):
272273
right_nodes.append(item.root)
@@ -299,7 +300,7 @@ def min(
299300
other: list[FormulaBuilder[QuantityT] | QuantityT | Formula[QuantityT]],
300301
) -> FormulaBuilder[QuantityT]:
301302
"""Create a min operation node."""
302-
right_nodes: list[_ast.Node] = []
303+
right_nodes: list[AstNode] = []
303304
for item in other:
304305
if isinstance(item, FormulaBuilder):
305306
right_nodes.append(item.root)
@@ -332,7 +333,7 @@ def max(
332333
other: list[FormulaBuilder[QuantityT] | QuantityT | Formula[QuantityT]],
333334
) -> FormulaBuilder[QuantityT]:
334335
"""Create a max operation node."""
335-
right_nodes: list[_ast.Node] = []
336+
right_nodes: list[AstNode] = []
336337
for item in other:
337338
if isinstance(item, FormulaBuilder):
338339
right_nodes.append(item.root)

src/frequenz/sdk/timeseries/formulas/_formula_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ...actor import Actor
1717
from .._base_types import QuantityT, Sample
1818
from . import _ast
19+
from ._base_ast_node import AstNode
1920
from ._resampled_stream_fetcher import ResampledStreamFetcher
2021

2122
_logger = logging.getLogger(__name__)
@@ -27,7 +28,7 @@ class FormulaEvaluatingActor(Generic[QuantityT], Actor):
2728
def __init__( # pylint: disable=too-many-arguments
2829
self,
2930
*,
30-
root: _ast.Node,
31+
root: AstNode,
3132
components: list[_ast.TelemetryStream[QuantityT]],
3233
create_method: Callable[[float], QuantityT],
3334
output_channel: Broadcast[Sample[QuantityT]],
@@ -47,7 +48,7 @@ def __init__( # pylint: disable=too-many-arguments
4748
"""
4849
super().__init__()
4950

50-
self._root: _ast.Node = root
51+
self._root: AstNode = root
5152
self._components: list[_ast.TelemetryStream[QuantityT]] = components
5253
self._create_method: Callable[[float], QuantityT] = create_method
5354
self._metric_fetcher: ResampledStreamFetcher | None = metric_fetcher

src/frequenz/sdk/timeseries/formulas/_parser.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from frequenz.sdk.timeseries._base_types import QuantityT
1414

1515
from . import _ast, _token
16+
from ._base_ast_node import AstNode
1617
from ._formula import Formula
1718
from ._functions import Function
1819
from ._lexer import Lexer
@@ -63,7 +64,7 @@ def __init__(
6364
self._components: list[_ast.TelemetryStream[QuantityT]] = []
6465
self._create_method: Callable[[float], QuantityT] = create_method
6566

66-
def _parse_term(self) -> _ast.Node | None:
67+
def _parse_term(self) -> AstNode | None:
6768
factor = self._parse_factor()
6869
if factor is None:
6970
return None
@@ -87,7 +88,7 @@ def _parse_term(self) -> _ast.Node | None:
8788

8889
return factor
8990

90-
def _parse_factor(self) -> _ast.Node | None:
91+
def _parse_factor(self) -> AstNode | None:
9192
unary = self._parse_unary()
9293

9394
if unary is None:
@@ -109,11 +110,11 @@ def _parse_factor(self) -> _ast.Node | None:
109110

110111
return unary
111112

112-
def _parse_unary(self) -> _ast.Node | None:
113+
def _parse_unary(self) -> AstNode | None:
113114
token: _token.Token | None = self._lexer.peek()
114115
if token is not None and isinstance(token, _token.Minus):
115116
token = next(self._lexer)
116-
primary: _ast.Node | None = self._parse_primary()
117+
primary: AstNode | None = self._parse_primary()
117118
if primary is None:
118119
raise ValueError(
119120
f"Expected primary expression after unary '-' at position {token.span}"
@@ -124,11 +125,11 @@ def _parse_unary(self) -> _ast.Node | None:
124125

125126
return self._parse_primary()
126127

127-
def _parse_bracketed(self) -> _ast.Node | None:
128+
def _parse_bracketed(self) -> AstNode | None:
128129
oparen = next(self._lexer) # consume '('
129130
assert isinstance(oparen, _token.OpenParen)
130131

131-
expr: _ast.Node | None = self._parse_term()
132+
expr: AstNode | None = self._parse_term()
132133
if expr is None:
133134
raise ValueError(f"Expected expression after '(' at position {oparen.span}")
134135

@@ -140,9 +141,9 @@ def _parse_bracketed(self) -> _ast.Node | None:
140141

141142
return expr
142143

143-
def _parse_function_call(self) -> _ast.Node | None:
144+
def _parse_function_call(self) -> AstNode | None:
144145
fn_name: _token.Token = next(self._lexer)
145-
args: list[_ast.Node] = []
146+
args: list[AstNode] = []
146147

147148
token: _token.Token | None = self._lexer.peek()
148149
if token is None or not isinstance(token, _token.OpenParen):
@@ -176,7 +177,7 @@ def _parse_function_call(self) -> _ast.Node | None:
176177
args=args,
177178
)
178179

179-
def _parse_primary(self) -> _ast.Node | None:
180+
def _parse_primary(self) -> AstNode | None:
180181
token: _token.Token | None = self._lexer.peek()
181182
if token is None:
182183
return None

0 commit comments

Comments
 (0)