Skip to content

Commit c376dec

Browse files
committed
add NodeSynchronizer for coordinating AST node evaluation
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 720748b commit c376dec

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

src/frequenz/sdk/timeseries/formulas/_base_ast_node.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""Formula AST node base class."""
55

66
import abc
7+
import asyncio
78
from dataclasses import dataclass
9+
from datetime import datetime
810
from typing import Generic
911

1012
from typing_extensions import override
@@ -35,3 +37,64 @@ def __str__(self) -> str:
3537
@abc.abstractmethod
3638
async def subscribe(self) -> None:
3739
"""Subscribe to any data streams needed by this node."""
40+
41+
42+
class NodeSynchronizer(Generic[QuantityT]):
43+
"""A helper class to synchronize multiple AST nodes."""
44+
45+
def __init__(self) -> None:
46+
"""Initialize this instance."""
47+
self._synchronized: bool = False
48+
49+
async def evaluate(
50+
self,
51+
nodes: list[AstNode[QuantityT]],
52+
target_timestamp: datetime | None = None,
53+
) -> list[Sample[QuantityT] | QuantityT | None]:
54+
"""Synchronize and evaluate multiple AST nodes.
55+
56+
Args:
57+
nodes: The AST nodes to synchronize and evaluate.
58+
target_timestamp: An optional maximum timestamp to synchronize to.
59+
60+
Returns:
61+
A list containing the evaluated values of the nodes.
62+
63+
Raises:
64+
RuntimeError: If synchronization fails after multiple attempts.
65+
"""
66+
if not self._synchronized or target_timestamp is not None:
67+
_ = await asyncio.gather(*(node.subscribe() for node in nodes))
68+
values = [await node.evaluate() for node in nodes]
69+
70+
target_timestamp = max(
71+
(value.timestamp for value in values if isinstance(value, Sample)),
72+
default=None,
73+
)
74+
if target_timestamp is None:
75+
self._synchronized = True
76+
return values
77+
78+
for i, value in enumerate(values):
79+
if isinstance(value, Sample):
80+
ctr = 0
81+
while ctr < 10 and value.timestamp < target_timestamp:
82+
value = await nodes[i].evaluate()
83+
if not isinstance(value, Sample):
84+
raise RuntimeError(
85+
"Subsequent AST node evaluation did not return a Sample"
86+
)
87+
values[i] = value
88+
ctr += 1
89+
if ctr >= 10 and value.timestamp < target_timestamp:
90+
raise RuntimeError(
91+
"Could not synchronize AST node evaluations after 10 tries"
92+
)
93+
if value.timestamp > target_timestamp:
94+
values[i] = Sample(target_timestamp, None)
95+
96+
self._synchronized = True
97+
98+
return values
99+
100+
return [await node.evaluate() for node in nodes]

0 commit comments

Comments
 (0)