Skip to content

Commit d272263

Browse files
committed
Add an unsubscribe method to AST nodes
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 95c61ec commit d272263

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

src/frequenz/sdk/timeseries/formulas/_ast.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ async def subscribe(self) -> None:
8989
raise RuntimeError("Metric fetcher is not set for TelemetryStream node.")
9090
self._stream = await self.metric_fetcher()
9191

92+
@override
93+
async def unsubscribe(self) -> None:
94+
"""Unsubscribe from the telemetry stream for this component."""
95+
if self._stream is None:
96+
return
97+
_logger.debug("Unsubscribing from telemetry stream for %s", self.source)
98+
self._stream.close()
99+
self._stream = None
100+
92101
def _is_quantity_sample(
93102
self, sample: Sample[QuantityT] | Sample[Quantity]
94103
) -> TypeIs[Sample[Quantity]]:
@@ -116,6 +125,11 @@ async def subscribe(self) -> None:
116125
"""Subscribe to any data streams needed by the function."""
117126
await self.function.subscribe()
118127

128+
@override
129+
async def unsubscribe(self) -> None:
130+
"""Unsubscribe from any data streams needed by the function."""
131+
await self.function.unsubscribe()
132+
119133

120134
@dataclass(kw_only=True)
121135
class Constant(AstNode[QuantityT]):
@@ -135,7 +149,11 @@ def format(self, wrap: bool = False) -> str:
135149

136150
@override
137151
async def subscribe(self) -> None:
138-
"""Subscribe to any data streams needed by the function."""
152+
"""No-op for constant node."""
153+
154+
@override
155+
async def unsubscribe(self) -> None:
156+
"""No-op for constant node."""
139157

140158

141159
@dataclass(kw_only=True)
@@ -196,12 +214,20 @@ def format(self, wrap: bool = False) -> str:
196214

197215
@override
198216
async def subscribe(self) -> None:
199-
"""Subscribe to any data streams needed by the function."""
217+
"""Subscribe to any data streams needed by this node."""
200218
_ = await asyncio.gather(
201219
self.left.subscribe(),
202220
self.right.subscribe(),
203221
)
204222

223+
@override
224+
async def unsubscribe(self) -> None:
225+
"""Unsubscribe from any data streams needed by this node."""
226+
_ = await asyncio.gather(
227+
self.left.unsubscribe(),
228+
self.right.unsubscribe(),
229+
)
230+
205231

206232
@dataclass(kw_only=True)
207233
class Sub(AstNode[QuantityT]):
@@ -261,12 +287,20 @@ def format(self, wrap: bool = False) -> str:
261287

262288
@override
263289
async def subscribe(self) -> None:
264-
"""Subscribe to any data streams needed by the function."""
290+
"""Subscribe to any data streams needed by this node."""
265291
_ = await asyncio.gather(
266292
self.left.subscribe(),
267293
self.right.subscribe(),
268294
)
269295

296+
@override
297+
async def unsubscribe(self) -> None:
298+
"""Unsubscribe from any data streams needed by this node."""
299+
_ = await asyncio.gather(
300+
self.left.unsubscribe(),
301+
self.right.unsubscribe(),
302+
)
303+
270304

271305
@dataclass(kw_only=True)
272306
class Mul(AstNode[QuantityT]):
@@ -325,12 +359,20 @@ def format(self, wrap: bool = False) -> str:
325359

326360
@override
327361
async def subscribe(self) -> None:
328-
"""Subscribe to any data streams needed by the function."""
362+
"""Subscribe to any data streams needed by this node."""
329363
_ = await asyncio.gather(
330364
self.left.subscribe(),
331365
self.right.subscribe(),
332366
)
333367

368+
@override
369+
async def unsubscribe(self) -> None:
370+
"""Unsubscribe from any data streams needed by this node."""
371+
_ = await asyncio.gather(
372+
self.left.unsubscribe(),
373+
self.right.unsubscribe(),
374+
)
375+
334376

335377
@dataclass(kw_only=True)
336378
class Div(AstNode[QuantityT]):
@@ -388,8 +430,16 @@ def format(self, wrap: bool = False) -> str:
388430

389431
@override
390432
async def subscribe(self) -> None:
391-
"""Subscribe to any data streams needed by the function."""
433+
"""Subscribe to any data streams needed by this node."""
392434
_ = await asyncio.gather(
393435
self.left.subscribe(),
394436
self.right.subscribe(),
395437
)
438+
439+
@override
440+
async def unsubscribe(self) -> None:
441+
"""Unsubscribe from any data streams needed by this node."""
442+
_ = await asyncio.gather(
443+
self.left.unsubscribe(),
444+
self.right.unsubscribe(),
445+
)

src/frequenz/sdk/timeseries/formulas/_base_ast_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def __str__(self) -> str:
3838
async def subscribe(self) -> None:
3939
"""Subscribe to any data streams needed by this node."""
4040

41+
@abc.abstractmethod
42+
async def unsubscribe(self) -> None:
43+
"""Unsubscribe from any data streams used by this node."""
44+
4145

4246
class NodeSynchronizer(Generic[QuantityT]):
4347
"""A helper class to synchronize multiple AST nodes."""

src/frequenz/sdk/timeseries/formulas/_functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ async def subscribe(self) -> None:
4747
*(param.subscribe() for param in self.params),
4848
)
4949

50+
async def unsubscribe(self) -> None:
51+
"""Unsubscribe from any data streams needed by the function."""
52+
_ = await asyncio.gather(
53+
*(param.unsubscribe() for param in self.params),
54+
)
55+
5056
@classmethod
5157
def from_string(
5258
cls, name: str, params: list[AstNode[QuantityT]]

0 commit comments

Comments
 (0)