Skip to content

Commit 5cac8af

Browse files
committed
Use covariant generic types for Receivers
This allows broadening of the message type, such that receivers that stream narrow types can be typed as receivers streaming broader types. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 977002c commit 5cac8af

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/frequenz/channels/_receiver.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@
140140

141141
from ._exceptions import Error
142142

143-
_T = TypeVar("_T")
144-
_U = TypeVar("_U")
143+
_T_co = TypeVar("_T_co", covariant=True)
144+
_U_co = TypeVar("_U_co", covariant=True)
145145

146146

147-
class Receiver(ABC, Generic[_T]):
147+
class Receiver(ABC, Generic[_T_co]):
148148
"""An endpoint to receive messages."""
149149

150-
async def __anext__(self) -> _T:
150+
async def __anext__(self) -> _T_co:
151151
"""Await the next value in the async iteration over received values.
152152
153153
Returns:
@@ -177,7 +177,7 @@ async def ready(self) -> bool:
177177
"""
178178

179179
@abstractmethod
180-
def consume(self) -> _T:
180+
def consume(self) -> _T_co:
181181
"""Return the latest value once `ready()` is complete.
182182
183183
`ready()` must be called before each call to `consume()`.
@@ -198,7 +198,7 @@ def __aiter__(self) -> Self:
198198
"""
199199
return self
200200

201-
async def receive(self) -> _T:
201+
async def receive(self) -> _T_co:
202202
"""Receive a message from the channel.
203203
204204
Returns:
@@ -225,7 +225,7 @@ async def receive(self) -> _T:
225225
raise ReceiverStoppedError(self) from exc
226226
return received
227227

228-
def map(self, call: Callable[[_T], _U]) -> Receiver[_U]:
228+
def map(self, call: Callable[[_T_co], _U_co]) -> Receiver[_U_co]:
229229
"""Return a receiver with `call` applied on incoming messages.
230230
231231
Args:
@@ -237,13 +237,13 @@ def map(self, call: Callable[[_T], _U]) -> Receiver[_U]:
237237
return _Map(self, call)
238238

239239

240-
class ReceiverError(Error, Generic[_T]):
240+
class ReceiverError(Error, Generic[_T_co]):
241241
"""An error produced in a [Receiver][frequenz.channels.Receiver].
242242
243243
All exceptions generated by receivers inherit from this exception.
244244
"""
245245

246-
def __init__(self, message: str, receiver: Receiver[_T]):
246+
def __init__(self, message: str, receiver: Receiver[_T_co]):
247247
"""Create an instance.
248248
249249
Args:
@@ -252,14 +252,14 @@ def __init__(self, message: str, receiver: Receiver[_T]):
252252
error happened.
253253
"""
254254
super().__init__(message)
255-
self.receiver: Receiver[_T] = receiver
255+
self.receiver: Receiver[_T_co] = receiver
256256
"""The receiver where the error happened."""
257257

258258

259-
class ReceiverStoppedError(ReceiverError[_T]):
259+
class ReceiverStoppedError(ReceiverError[_T_co]):
260260
"""The [Receiver][frequenz.channels.Receiver] stopped producing messages."""
261261

262-
def __init__(self, receiver: Receiver[_T]):
262+
def __init__(self, receiver: Receiver[_T_co]):
263263
"""Create an instance.
264264
265265
Args:
@@ -269,7 +269,7 @@ def __init__(self, receiver: Receiver[_T]):
269269
super().__init__(f"Receiver {receiver} was stopped", receiver)
270270

271271

272-
class _Map(Receiver[_U], Generic[_T, _U]):
272+
class _Map(Receiver[_U_co], Generic[_T_co, _U_co]):
273273
"""Apply a transform function on a channel receiver.
274274
275275
Has two generic types:
@@ -278,17 +278,19 @@ class _Map(Receiver[_U], Generic[_T, _U]):
278278
- The output type: return type of the transform method.
279279
"""
280280

281-
def __init__(self, receiver: Receiver[_T], transform: Callable[[_T], _U]) -> None:
281+
def __init__(
282+
self, receiver: Receiver[_T_co], transform: Callable[[_T_co], _U_co]
283+
) -> None:
282284
"""Create a `Transform` instance.
283285
284286
Args:
285287
receiver: The input receiver.
286288
transform: The function to run on the input data.
287289
"""
288-
self._receiver: Receiver[_T] = receiver
290+
self._receiver: Receiver[_T_co] = receiver
289291
"""The input receiver."""
290292

291-
self._transform: Callable[[_T], _U] = transform
293+
self._transform: Callable[[_T_co], _U_co] = transform
292294
"""The function to run on the input data."""
293295

294296
async def ready(self) -> bool:
@@ -306,7 +308,7 @@ async def ready(self) -> bool:
306308

307309
# We need a noqa here because the docs have a Raises section but the code doesn't
308310
# explicitly raise anything.
309-
def consume(self) -> _U: # noqa: DOC502
311+
def consume(self) -> _U_co: # noqa: DOC502
310312
"""Return a transformed value once `ready()` is complete.
311313
312314
Returns:

0 commit comments

Comments
 (0)