66from __future__ import annotations
77
88from abc import ABC , abstractmethod
9- from typing import Callable , Generic , Optional , TypeVar
9+ from typing import Any , Callable , Generic , Optional , TypeVar
1010
1111T = TypeVar ("T" )
1212U = TypeVar ("U" )
1313
1414
15+ class ChannelError (RuntimeError ):
16+ """Base channel error.
17+
18+ All exceptions generated by channels inherit from this exception.
19+ """
20+
21+ def __init__ (self , message : Any , channel : Any = None ):
22+ """Create a ChannelError instance.
23+
24+ Args:
25+ message: An error message.
26+ channel: A reference to the channel that encountered the error.
27+ """
28+ super ().__init__ (message )
29+ self .channel : Any = channel
30+
31+
32+ class ChannelClosedError (ChannelError ):
33+ """Error raised when trying to operate on a closed channel."""
34+
35+ def __init__ (self , channel : Any = None ):
36+ """Create a `ChannelClosedError` instance.
37+
38+ Args:
39+ channel: A reference to the channel that was closed.
40+ """
41+ super ().__init__ (f"Channel { channel } was closed" , channel )
42+
43+
1544class Sender (ABC , Generic [T ]):
1645 """A channel Sender."""
1746
@@ -31,12 +60,43 @@ async def send(self, msg: T) -> bool:
3160class Receiver (ABC , Generic [T ]):
3261 """A channel Receiver."""
3362
63+ async def __anext__ (self ) -> T :
64+ """Await the next value in the async iteration over received values.
65+
66+ Returns:
67+ The next value received.
68+
69+ Raises:
70+ StopAsyncIteration: if the underlying channel is closed.
71+ """
72+ try :
73+ await self .ready ()
74+ return self .consume ()
75+ except ChannelClosedError as exc :
76+ raise StopAsyncIteration () from exc
77+
3478 @abstractmethod
35- async def receive (self ) -> Optional [T ]:
36- """Receive a message from the channel.
79+ async def ready (self ) -> None :
80+ """Wait until the receiver is ready with a value.
81+
82+ Once a call to `ready()` has finished, the value should be read with a call to
83+ `consume()`.
84+
85+ Raises:
86+ ChannelClosedError: if the underlying channel is closed.
87+ """
88+
89+ @abstractmethod
90+ def consume (self ) -> T :
91+ """Return the latest value once `ready()` is complete.
92+
93+ `ready()` must be called before each call to `consume()`.
3794
3895 Returns:
39- `None`, if the channel is closed, a message otherwise.
96+ The next value received.
97+
98+ Raises:
99+ ChannelClosedError: if the underlying channel is closed.
40100 """
41101
42102 def __aiter__ (self ) -> Receiver [T ]:
@@ -47,19 +107,19 @@ def __aiter__(self) -> Receiver[T]:
47107 """
48108 return self
49109
50- async def __anext__ (self ) -> T :
51- """Await the next value in the async iteration over received values.
52-
53- Returns:
54- The next value received.
110+ async def receive (self ) -> T :
111+ """Receive a message from the channel.
55112
56113 Raises:
57- StopAsyncIteration: if we receive `None`, i.e. if the underlying
58- channel is closed.
114+ ChannelClosedError: if the underlying channel is closed.
115+
116+ Returns:
117+ The received message.
59118 """
60- received = await self .receive ()
61- if received is None :
62- raise StopAsyncIteration
119+ try :
120+ received = await self .__anext__ () # pylint: disable=unnecessary-dunder-call
121+ except StopAsyncIteration as exc :
122+ raise ChannelClosedError () from exc
63123 return received
64124
65125 def map (self , call : Callable [[T ], U ]) -> Receiver [U ]:
@@ -136,13 +196,14 @@ def __init__(self, recv: Receiver[T], transform: Callable[[T], U]) -> None:
136196 self ._recv = recv
137197 self ._transform = transform
138198
139- async def receive (self ) -> Optional [U ]:
140- """Return a transformed message received from the input channel.
199+ async def ready (self ) -> None :
200+ """Wait until the receiver is ready with a value."""
201+ await self ._recv .ready () # pylint: disable=protected-access
202+
203+ def consume (self ) -> U :
204+ """Return a transformed value once `ready()` is complete.
141205
142206 Returns:
143- `None`, if the channel is closed, a message otherwise .
207+ The next value that was received .
144208 """
145- msg = await self ._recv .receive ()
146- if msg is None :
147- return None
148- return self ._transform (msg )
209+ return self ._transform (self ._recv .consume ()) # pylint: disable=protected-access
0 commit comments