|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | from abc import ABC, abstractmethod |
9 | | -from typing import Callable, Generic, Optional, TypeVar |
| 9 | +from typing import Any, Callable, Generic, Optional, TypeVar |
10 | 10 |
|
11 | 11 | T = TypeVar("T") |
12 | 12 | U = TypeVar("U") |
13 | 13 |
|
14 | 14 |
|
| 15 | +class ChannelError(RuntimeError): |
| 16 | + """Base channel error. |
| 17 | +
|
| 18 | + All exceptions generated by channels inherit from this exception. |
| 19 | + """ |
| 20 | + |
| 21 | + def __init__(self, message: Any, channel: Any = None): |
| 22 | + """Create a ChannelError instance. |
| 23 | +
|
| 24 | + Args: |
| 25 | + message: An error message. |
| 26 | + channel: A reference to the channel that encountered the error. |
| 27 | + """ |
| 28 | + super().__init__(message) |
| 29 | + self.channel: Any = channel |
| 30 | + |
| 31 | + |
| 32 | +class ChannelClosedError(ChannelError): |
| 33 | + """Error raised when trying to operate on a closed channel.""" |
| 34 | + |
| 35 | + def __init__(self, channel: Any = None): |
| 36 | + """Create a `ChannelClosedError` instance. |
| 37 | +
|
| 38 | + Args: |
| 39 | + channel: A reference to the channel that was closed. |
| 40 | + """ |
| 41 | + super().__init__(f"Channel {channel} was closed", channel) |
| 42 | + |
| 43 | + |
15 | 44 | class Sender(ABC, Generic[T]): |
16 | 45 | """A channel Sender.""" |
17 | 46 |
|
@@ -50,16 +79,19 @@ def __aiter__(self) -> Receiver[T]: |
50 | 79 | """ |
51 | 80 | return self |
52 | 81 |
|
53 | | - async def receive(self) -> Optional[T]: |
| 82 | + async def receive(self) -> T: |
54 | 83 | """Receive a message from the channel. |
55 | 84 |
|
| 85 | + Raises: |
| 86 | + ChannelClosedError: if the underlying channel is closed. |
| 87 | +
|
56 | 88 | Returns: |
57 | 89 | The received message. |
58 | 90 | """ |
59 | 91 | try: |
60 | 92 | received = await self.__anext__() # pylint: disable=unnecessary-dunder-call |
61 | | - except StopAsyncIteration: |
62 | | - return None |
| 93 | + except StopAsyncIteration as exc: |
| 94 | + raise ChannelClosedError() from exc |
63 | 95 | return received |
64 | 96 |
|
65 | 97 | def map(self, call: Callable[[T], U]) -> Receiver[U]: |
|
0 commit comments