Skip to content

Commit 6f775c4

Browse files
committed
Add a Pipe implementation
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent bfcee7b commit 6f775c4

File tree

4 files changed

+158
-0
lines changed

4 files changed

+158
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Experimental channel primitives.
5+
6+
Warning:
7+
This package contains experimental channel primitives that are not yet
8+
considered stable. They are subject to change without notice, including
9+
removal, even in minor updates.
10+
"""
11+
12+
from ._pipe import Pipe
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Pipe between a receiver and a sender.
5+
6+
The `Pipe` class takes a receiver and a sender and creates a pipe between them by
7+
forwarding all the messages received by the receiver to the sender.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import asyncio
13+
import typing
14+
15+
from .._generic import ChannelMessageT
16+
from .._receiver import Receiver
17+
from .._sender import Sender
18+
19+
20+
class Pipe(typing.Generic[ChannelMessageT]):
21+
"""A pipe between two channels.
22+
23+
The `Pipe` class takes a receiver and a sender and creates a pipe between them
24+
by forwarding all the messages received by the receiver to the sender.
25+
26+
Example:
27+
```python
28+
from frequenz.channels import Broadcast, Pipe
29+
30+
channel1: Broadcast[int] = Broadcast(name="channel1")
31+
channel2: Broadcast[int] = Broadcast(name="channel2")
32+
33+
receiver_chan1 = channel1.new_receiver()
34+
sender_chan2 = channel2.new_sender()
35+
36+
async with Pipe(channel2.new_receiver(), channel1.new_sender()):
37+
await sender_chan2.send(10)
38+
assert await receiver_chan1.receive() == 10
39+
```
40+
"""
41+
42+
def __init__(
43+
self, receiver: Receiver[ChannelMessageT], sender: Sender[ChannelMessageT]
44+
) -> None:
45+
"""Create a new pipe between two channels.
46+
47+
Args:
48+
receiver: The receiver channel.
49+
sender: The sender channel.
50+
"""
51+
self._sender = sender
52+
self._receiver = receiver
53+
self._task: asyncio.Task[None] | None = None
54+
55+
async def __aenter__(self) -> Pipe[ChannelMessageT]:
56+
"""Enter the runtime context."""
57+
await self.start()
58+
return self
59+
60+
async def __aexit__(
61+
self,
62+
_exc_type: typing.Type[BaseException],
63+
_exc: BaseException,
64+
_tb: typing.Any,
65+
) -> None:
66+
"""Exit the runtime context."""
67+
await self.stop()
68+
69+
async def start(self) -> None:
70+
"""Start this pipe if it is not already running."""
71+
if not self._task or self._task.done():
72+
self._task = asyncio.create_task(self._run())
73+
74+
async def stop(self) -> None:
75+
"""Stop this pipe."""
76+
if self._task and not self._task.done():
77+
self._task.cancel()
78+
try:
79+
await self._task
80+
except asyncio.CancelledError:
81+
pass
82+
83+
async def _run(self) -> None:
84+
async for value in self._receiver:
85+
await self._sender.send(value)

tests/experimental/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for experimental channel primitives."""

tests/experimental/test_pipe.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the Pipe class."""
5+
6+
7+
import asyncio
8+
import typing
9+
10+
from frequenz.channels import Broadcast, Receiver
11+
from frequenz.channels.experimental import Pipe
12+
13+
T = typing.TypeVar("T")
14+
15+
16+
class Timeout:
17+
"""Sentinel for timeout."""
18+
19+
20+
async def receive_timeout(recv: Receiver[T], timeout: float = 0.1) -> T | type[Timeout]:
21+
"""Receive message from receiver with timeout."""
22+
try:
23+
return await asyncio.wait_for(recv.receive(), timeout=timeout)
24+
except asyncio.TimeoutError:
25+
return Timeout
26+
27+
28+
async def test_pipe() -> None:
29+
"""Test pipe."""
30+
channel1: Broadcast[int] = Broadcast(name="channel1")
31+
channel2: Broadcast[int] = Broadcast(name="channel2")
32+
33+
sender_chan1 = channel1.new_sender()
34+
sender_chan2 = channel2.new_sender()
35+
receiver_chan1 = channel1.new_receiver()
36+
receiver_chan2 = channel2.new_receiver()
37+
38+
async with Pipe(channel2.new_receiver(), channel1.new_sender()):
39+
await sender_chan2.send(42)
40+
assert await receive_timeout(receiver_chan1) == 42
41+
assert await receive_timeout(receiver_chan2) == 42
42+
43+
await sender_chan2.send(-2)
44+
assert await receive_timeout(receiver_chan1) == -2
45+
assert await receive_timeout(receiver_chan2) == -2
46+
47+
await sender_chan1.send(43)
48+
assert await receive_timeout(receiver_chan1) == 43
49+
assert await receive_timeout(receiver_chan2) is Timeout
50+
51+
await sender_chan2.send(5)
52+
assert await receive_timeout(receiver_chan1) == 5
53+
assert await receive_timeout(receiver_chan2) == 5
54+
55+
await sender_chan2.send(5)
56+
assert await receive_timeout(receiver_chan1) is Timeout
57+
assert await receive_timeout(receiver_chan2) == 5

0 commit comments

Comments
 (0)