Skip to content

Commit 2267927

Browse files
authored
Support broadening or narrowing of types in Receivers and Senders (#262)
This allows Receivers to reduce message types to be broader than actual, and senders to accept narrower types than required. ```python @DataClass class Broader: """A broad class.""" value: int class Actual(Broader): """Actual class.""" class Narrower(Actual): """A narrower class.""" chan = Broadcast[Actual](name="input-chan") sender: Sender[Narrower] = chan.new_sender() receiver: Receiver[Broader] = chan.new_receiver() ```
2 parents 977002c + 25751ba commit 2267927

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

RELEASE_NOTES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@
147147

148148
- `map()`: The returned map object now has a more useful implementation of `__str__ and `__repr__`.
149149

150+
## Improvements
151+
152+
* `Receiver`: Use a covariant generic type, which allows the generic type to be broader than the actual type.
153+
154+
* `Sender`: Use a contravariant generic type, which allows the generic type to be narrower than the required type.
155+
150156
## Bug Fixes
151157

152158
* `Timer`: Fix bug that was causing calls to `reset()` to not reset the timer, if the timer was already being awaited.

src/frequenz/channels/_receiver.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@
140140

141141
from ._exceptions import Error
142142

143-
_T = TypeVar("_T")
144-
_U = TypeVar("_U")
143+
_T_co = TypeVar("_T_co", covariant=True)
144+
_U_co = TypeVar("_U_co", covariant=True)
145145

146146

147-
class Receiver(ABC, Generic[_T]):
147+
class Receiver(ABC, Generic[_T_co]):
148148
"""An endpoint to receive messages."""
149149

150-
async def __anext__(self) -> _T:
150+
async def __anext__(self) -> _T_co:
151151
"""Await the next value in the async iteration over received values.
152152
153153
Returns:
@@ -177,7 +177,7 @@ async def ready(self) -> bool:
177177
"""
178178

179179
@abstractmethod
180-
def consume(self) -> _T:
180+
def consume(self) -> _T_co:
181181
"""Return the latest value once `ready()` is complete.
182182
183183
`ready()` must be called before each call to `consume()`.
@@ -198,7 +198,7 @@ def __aiter__(self) -> Self:
198198
"""
199199
return self
200200

201-
async def receive(self) -> _T:
201+
async def receive(self) -> _T_co:
202202
"""Receive a message from the channel.
203203
204204
Returns:
@@ -225,7 +225,7 @@ async def receive(self) -> _T:
225225
raise ReceiverStoppedError(self) from exc
226226
return received
227227

228-
def map(self, call: Callable[[_T], _U]) -> Receiver[_U]:
228+
def map(self, call: Callable[[_T_co], _U_co]) -> Receiver[_U_co]:
229229
"""Return a receiver with `call` applied on incoming messages.
230230
231231
Args:
@@ -237,13 +237,13 @@ def map(self, call: Callable[[_T], _U]) -> Receiver[_U]:
237237
return _Map(self, call)
238238

239239

240-
class ReceiverError(Error, Generic[_T]):
240+
class ReceiverError(Error, Generic[_T_co]):
241241
"""An error produced in a [Receiver][frequenz.channels.Receiver].
242242
243243
All exceptions generated by receivers inherit from this exception.
244244
"""
245245

246-
def __init__(self, message: str, receiver: Receiver[_T]):
246+
def __init__(self, message: str, receiver: Receiver[_T_co]):
247247
"""Create an instance.
248248
249249
Args:
@@ -252,14 +252,14 @@ def __init__(self, message: str, receiver: Receiver[_T]):
252252
error happened.
253253
"""
254254
super().__init__(message)
255-
self.receiver: Receiver[_T] = receiver
255+
self.receiver: Receiver[_T_co] = receiver
256256
"""The receiver where the error happened."""
257257

258258

259-
class ReceiverStoppedError(ReceiverError[_T]):
259+
class ReceiverStoppedError(ReceiverError[_T_co]):
260260
"""The [Receiver][frequenz.channels.Receiver] stopped producing messages."""
261261

262-
def __init__(self, receiver: Receiver[_T]):
262+
def __init__(self, receiver: Receiver[_T_co]):
263263
"""Create an instance.
264264
265265
Args:
@@ -269,7 +269,7 @@ def __init__(self, receiver: Receiver[_T]):
269269
super().__init__(f"Receiver {receiver} was stopped", receiver)
270270

271271

272-
class _Map(Receiver[_U], Generic[_T, _U]):
272+
class _Map(Receiver[_U_co], Generic[_T_co, _U_co]):
273273
"""Apply a transform function on a channel receiver.
274274
275275
Has two generic types:
@@ -278,17 +278,19 @@ class _Map(Receiver[_U], Generic[_T, _U]):
278278
- The output type: return type of the transform method.
279279
"""
280280

281-
def __init__(self, receiver: Receiver[_T], transform: Callable[[_T], _U]) -> None:
281+
def __init__(
282+
self, receiver: Receiver[_T_co], transform: Callable[[_T_co], _U_co]
283+
) -> None:
282284
"""Create a `Transform` instance.
283285
284286
Args:
285287
receiver: The input receiver.
286288
transform: The function to run on the input data.
287289
"""
288-
self._receiver: Receiver[_T] = receiver
290+
self._receiver: Receiver[_T_co] = receiver
289291
"""The input receiver."""
290292

291-
self._transform: Callable[[_T], _U] = transform
293+
self._transform: Callable[[_T_co], _U_co] = transform
292294
"""The function to run on the input data."""
293295

294296
async def ready(self) -> bool:
@@ -306,7 +308,7 @@ async def ready(self) -> bool:
306308

307309
# We need a noqa here because the docs have a Raises section but the code doesn't
308310
# explicitly raise anything.
309-
def consume(self) -> _U: # noqa: DOC502
311+
def consume(self) -> _U_co: # noqa: DOC502
310312
"""Return a transformed value once `ready()` is complete.
311313
312314
Returns:

src/frequenz/channels/_sender.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@
5454

5555
from ._exceptions import Error
5656

57-
_T = TypeVar("_T")
57+
_T_contra = TypeVar("_T_contra", contravariant=True)
5858

5959

60-
class Sender(ABC, Generic[_T]):
60+
class Sender(ABC, Generic[_T_contra]):
6161
"""An endpoint to sends messages."""
6262

6363
@abstractmethod
64-
async def send(self, msg: _T) -> None:
64+
async def send(self, msg: _T_contra) -> None:
6565
"""Send a message to the channel.
6666
6767
Args:
@@ -72,13 +72,13 @@ async def send(self, msg: _T) -> None:
7272
"""
7373

7474

75-
class SenderError(Error, Generic[_T]):
75+
class SenderError(Error, Generic[_T_contra]):
7676
"""An error produced in a [Sender][frequenz.channels.Sender].
7777
7878
All exceptions generated by senders inherit from this exception.
7979
"""
8080

81-
def __init__(self, message: str, sender: Sender[_T]):
81+
def __init__(self, message: str, sender: Sender[_T_contra]):
8282
"""Create an instance.
8383
8484
Args:
@@ -87,5 +87,5 @@ def __init__(self, message: str, sender: Sender[_T]):
8787
happened.
8888
"""
8989
super().__init__(message)
90-
self.sender: Sender[_T] = sender
90+
self.sender: Sender[_T_contra] = sender
9191
"""The sender where the error happened."""

tests/test_broadcast.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
import asyncio
8+
from dataclasses import dataclass
89

910
import pytest
1011

@@ -252,3 +253,27 @@ async def test_broadcast_receiver_drop() -> None:
252253

253254
assert len(chan._receivers) == 1
254255
# pylint: enable=protected-access
256+
257+
258+
async def test_type_variance() -> None:
259+
"""Ensure that the type variance of Broadcast is working."""
260+
261+
@dataclass
262+
class Broader:
263+
"""A broad class."""
264+
265+
value: int
266+
267+
class Actual(Broader):
268+
"""Actual class."""
269+
270+
class Narrower(Actual):
271+
"""A narrower class."""
272+
273+
chan = Broadcast[Actual](name="input-chan")
274+
275+
sender: Sender[Narrower] = chan.new_sender()
276+
receiver: Receiver[Broader] = chan.new_receiver()
277+
278+
await sender.send(Narrower(10))
279+
assert (await receiver.receive()).value == 10

0 commit comments

Comments
 (0)