From de81b6beb543c713a6dfaf8ee5c8376ca05a872a Mon Sep 17 00:00:00 2001 From: Sahas Subramanian Date: Fri, 29 Nov 2024 15:16:32 +0100 Subject: [PATCH 1/4] Add `@override` decorators Classes that implement the `Sender` or `Receiver` interfaces currently need to override the `send` or the `ready` and `consume` methods respectively. As we add more methods, to these interfaces, it becomes hard to track which methods are there for internal use and which ones are there to implement the interface. Using the `override` decorators helps with that. Signed-off-by: Sahas Subramanian --- src/frequenz/channels/_anycast.py | 5 +++++ src/frequenz/channels/_broadcast.py | 5 +++++ src/frequenz/channels/_merge.py | 4 ++++ src/frequenz/channels/_receiver.py | 6 ++++++ src/frequenz/channels/event.py | 4 ++++ src/frequenz/channels/experimental/_relay_sender.py | 3 +++ src/frequenz/channels/file_watcher.py | 3 +++ src/frequenz/channels/timer.py | 4 ++++ 8 files changed, 34 insertions(+) diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index 1fadbef8..aaa6e6df 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -10,6 +10,8 @@ from collections import deque from typing import Generic, TypeVar +from typing_extensions import override + from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError @@ -320,6 +322,7 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this sender belongs to.""" + @override async def send(self, message: _T, /) -> None: """Send a message across the channel. @@ -390,6 +393,7 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._next: _T | type[_Empty] = _Empty + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -417,6 +421,7 @@ async def ready(self) -> bool: # pylint: enable=protected-access return True + @override def consume(self) -> _T: """Return the latest message once `ready()` is complete. diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index c1017cd7..26bcaec7 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -11,6 +11,8 @@ from collections import deque from typing import Generic, TypeVar +from typing_extensions import override + from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError @@ -327,6 +329,7 @@ def __init__(self, channel: Broadcast[_T], /) -> None: self._channel: Broadcast[_T] = channel """The broadcast channel this sender belongs to.""" + @override async def send(self, message: _T, /) -> None: """Send a message to all broadcast receivers. @@ -441,6 +444,7 @@ def __len__(self) -> int: """ return len(self._q) + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -469,6 +473,7 @@ async def ready(self) -> bool: return True # pylint: enable=protected-access + @override def consume(self) -> _T: """Return the latest message once `ready` is complete. diff --git a/src/frequenz/channels/_merge.py b/src/frequenz/channels/_merge.py index 3a306dbb..479e57ae 100644 --- a/src/frequenz/channels/_merge.py +++ b/src/frequenz/channels/_merge.py @@ -54,6 +54,8 @@ from collections import deque from typing import Any +from typing_extensions import override + from ._generic import ReceiverMessageT_co from ._receiver import Receiver, ReceiverStoppedError @@ -135,6 +137,7 @@ async def stop(self) -> None: await asyncio.gather(*self._pending, return_exceptions=True) self._pending = set() + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -171,6 +174,7 @@ async def ready(self) -> bool: asyncio.create_task(anext(self._receivers[name]), name=name) ) + @override def consume(self) -> ReceiverMessageT_co: """Return the latest message once `ready` is complete. diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 53862a45..e5715f6a 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -157,6 +157,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload +from typing_extensions import override + from ._exceptions import Error from ._generic import MappedMessageT_co, ReceiverMessageT_co @@ -433,6 +435,7 @@ def __init__( ) """The function to apply on the input data.""" + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -448,6 +451,7 @@ async def ready(self) -> bool: # We need a noqa here because the docs have a Raises section but the code doesn't # explicitly raise anything. + @override def consume(self) -> MappedMessageT_co: # noqa: DOC502 """Return a transformed message once `ready()` is complete. @@ -509,6 +513,7 @@ def __init__( self._recv_closed = False + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -528,6 +533,7 @@ async def ready(self) -> bool: self._recv_closed = True return False + @override def consume(self) -> ReceiverMessageT_co: """Return a transformed message once `ready()` is complete. diff --git a/src/frequenz/channels/event.py b/src/frequenz/channels/event.py index 0d599e89..9391c694 100644 --- a/src/frequenz/channels/event.py +++ b/src/frequenz/channels/event.py @@ -16,6 +16,8 @@ import asyncio as _asyncio +from typing_extensions import override + from frequenz.channels._receiver import Receiver, ReceiverStoppedError @@ -141,6 +143,7 @@ def set(self) -> None: self._is_set = True self._event.set() + @override async def ready(self) -> bool: """Wait until this receiver is ready. @@ -152,6 +155,7 @@ async def ready(self) -> bool: await self._event.wait() return not self._is_stopped + @override def consume(self) -> None: """Consume the event. diff --git a/src/frequenz/channels/experimental/_relay_sender.py b/src/frequenz/channels/experimental/_relay_sender.py index 3d78b474..398ba8d5 100644 --- a/src/frequenz/channels/experimental/_relay_sender.py +++ b/src/frequenz/channels/experimental/_relay_sender.py @@ -9,6 +9,8 @@ import typing +from typing_extensions import override + from .._generic import SenderMessageT_contra from .._sender import Sender @@ -46,6 +48,7 @@ def __init__(self, *senders: Sender[SenderMessageT_contra]) -> None: """ self._senders = senders + @override async def send(self, message: SenderMessageT_contra, /) -> None: """Send a message. diff --git a/src/frequenz/channels/file_watcher.py b/src/frequenz/channels/file_watcher.py index a4dc3074..ac66e95e 100644 --- a/src/frequenz/channels/file_watcher.py +++ b/src/frequenz/channels/file_watcher.py @@ -25,6 +25,7 @@ from datetime import timedelta from enum import Enum +from typing_extensions import override from watchfiles import Change, awatch from watchfiles.main import FileChange @@ -185,6 +186,7 @@ def __del__(self) -> None: # is stopped. self._stop_event.set() + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -212,6 +214,7 @@ async def ready(self) -> bool: return True + @override def consume(self) -> Event: """Return the latest event once `ready` is complete. diff --git a/src/frequenz/channels/timer.py b/src/frequenz/channels/timer.py index 2785feea..24531c93 100644 --- a/src/frequenz/channels/timer.py +++ b/src/frequenz/channels/timer.py @@ -102,6 +102,8 @@ async def main() -> None: import asyncio from datetime import timedelta +from typing_extensions import override + from ._receiver import Receiver, ReceiverStoppedError @@ -644,6 +646,7 @@ def stop(self) -> None: # We need a noqa here because the docs have a Raises section but the documented # exceptions are raised indirectly. + @override async def ready(self) -> bool: # noqa: DOC502 """Wait until the timer `interval` passed. @@ -715,6 +718,7 @@ async def ready(self) -> bool: # noqa: DOC502 return True + @override def consume(self) -> timedelta: """Return the latest drift once `ready()` is complete. From cea6cfcfd1a7d1754c0bc051e8caf65f8ec36573 Mon Sep 17 00:00:00 2001 From: Sahas Subramanian Date: Fri, 29 Nov 2024 15:54:43 +0100 Subject: [PATCH 2/4] Add a `close` method to the `Receiver` interface Also implement the method in all classes implementing the `Receiver` interface. Signed-off-by: Sahas Subramanian --- src/frequenz/channels/_anycast.py | 17 ++++++++++++++++ src/frequenz/channels/_broadcast.py | 21 ++++++++++++++++++- src/frequenz/channels/_merge.py | 15 ++++++++++++++ src/frequenz/channels/_receiver.py | 29 +++++++++++++++++++++++++++ src/frequenz/channels/event.py | 5 +++++ src/frequenz/channels/file_watcher.py | 5 +++++ src/frequenz/channels/timer.py | 5 +++++ 7 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index aaa6e6df..a3d2a846 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -391,6 +391,9 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this receiver belongs to.""" + self._closed: bool = False + """Whether the receiver is closed.""" + self._next: _T | type[_Empty] = _Empty @override @@ -409,6 +412,9 @@ async def ready(self) -> bool: if self._next is not _Empty: return True + if self._closed: + return False + # pylint: disable=protected-access while len(self._channel._deque) == 0: if self._channel._closed: @@ -436,6 +442,9 @@ def consume(self) -> _T: ): raise ReceiverStoppedError(self) from ChannelClosedError(self._channel) + if self._next is _Empty and self._closed: + raise ReceiverStoppedError(self) + assert ( self._next is not _Empty ), "`consume()` must be preceded by a call to `ready()`" @@ -446,6 +455,14 @@ def consume(self) -> _T: return next_val + @override + def close(self) -> None: + """Close this receiver. + + After closing, the receiver will not be able to receive any more messages. + """ + self._closed = True + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index 26bcaec7..cd31a9f4 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -417,6 +417,9 @@ def __init__( self._q: deque[_T] = deque(maxlen=limit) """The receiver's internal message queue.""" + self._closed: bool = False + """Whether the receiver is closed.""" + def enqueue(self, message: _T, /) -> None: """Put a message into this receiver's queue. @@ -466,7 +469,7 @@ async def ready(self) -> bool: # consumed, then we return immediately. # pylint: disable=protected-access while len(self._q) == 0: - if self._channel._closed: + if self._channel._closed or self._closed: return False async with self._channel._recv_cv: await self._channel._recv_cv.wait() @@ -486,9 +489,25 @@ def consume(self) -> _T: if not self._q and self._channel._closed: # pylint: disable=protected-access raise ReceiverStoppedError(self) from ChannelClosedError(self._channel) + if self._closed: + raise ReceiverStoppedError(self) + assert self._q, "`consume()` must be preceded by a call to `ready()`" return self._q.popleft() + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._closed = True + self._channel._receivers.pop( # pylint: disable=protected-access + hash(self), None + ) + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_merge.py b/src/frequenz/channels/_merge.py index 479e57ae..b1f38857 100644 --- a/src/frequenz/channels/_merge.py +++ b/src/frequenz/channels/_merge.py @@ -191,6 +191,21 @@ def consume(self) -> ReceiverMessageT_co: return self._results.popleft() + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + for task in self._pending: + if not task.done() and task.get_loop().is_running(): + task.cancel() + self._pending = set() + for recv in self._receivers.values(): + recv.close() + def __str__(self) -> str: """Return a string representation of this receiver.""" if len(self._receivers) > 3: diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index e5715f6a..b7d5e306 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -217,6 +217,15 @@ def consume(self) -> ReceiverMessageT_co: ReceiverError: If there is some problem with the receiver. """ + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be available from the receiver. + Once the receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + raise NotImplementedError("close() must be implemented by subclasses") + def __aiter__(self) -> Self: """Get an async iterator over the received messages. @@ -464,6 +473,16 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 """ return self._mapping_function(self._receiver.consume()) + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._receiver.close() + def __str__(self) -> str: """Return a string representation of the mapper.""" return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}" @@ -553,6 +572,16 @@ def consume(self) -> ReceiverMessageT_co: self._next_message = _SENTINEL return message + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._receiver.close() + def __str__(self) -> str: """Return a string representation of the filter.""" return f"{type(self).__name__}:{self._receiver}:{self._filter_function}" diff --git a/src/frequenz/channels/event.py b/src/frequenz/channels/event.py index 9391c694..5d1bd425 100644 --- a/src/frequenz/channels/event.py +++ b/src/frequenz/channels/event.py @@ -172,6 +172,11 @@ def consume(self) -> None: self._is_set = False self._event.clear() + @override + def close(self) -> None: + """Close this receiver.""" + self.stop() + def __str__(self) -> str: """Return a string representation of this event.""" return f"{type(self).__name__}({self._name!r})" diff --git a/src/frequenz/channels/file_watcher.py b/src/frequenz/channels/file_watcher.py index ac66e95e..e9ff4ca4 100644 --- a/src/frequenz/channels/file_watcher.py +++ b/src/frequenz/channels/file_watcher.py @@ -232,6 +232,11 @@ def consume(self) -> Event: change, path_str = self._changes.pop() return Event(type=EventType(change), path=pathlib.Path(path_str)) + @override + def close(self) -> None: + """Close this receiver.""" + self._stop_event.set() + def __str__(self) -> str: """Return a string representation of this receiver.""" if len(self._paths) > 3: diff --git a/src/frequenz/channels/timer.py b/src/frequenz/channels/timer.py index 24531c93..998430e9 100644 --- a/src/frequenz/channels/timer.py +++ b/src/frequenz/channels/timer.py @@ -745,6 +745,11 @@ def consume(self) -> timedelta: self._current_drift = None return drift + @override + def close(self) -> None: + """Close the timer.""" + self.stop() + def _now(self) -> int: """Return the current monotonic clock time in microseconds. From e0f4b12ca73355c725352b040f3f7f7784b6d63b Mon Sep 17 00:00:00 2001 From: Sahas Subramanian Date: Tue, 7 Jan 2025 11:27:21 +0100 Subject: [PATCH 3/4] Add tests for all `Receiver.close()` implementations Signed-off-by: Sahas Subramanian --- tests/test_anycast.py | 27 +++++++ tests/test_broadcast.py | 105 +++++++++++++++++++++++++ tests/test_event.py | 48 +++++++++++ tests/test_file_watcher_integration.py | 37 +++++++++ tests/test_merge_integration.py | 43 ++++++++++ tests/test_timer.py | 16 ++++ 6 files changed, 276 insertions(+) diff --git a/tests/test_anycast.py b/tests/test_anycast.py index c6db0d9a..918c548c 100644 --- a/tests/test_anycast.py +++ b/tests/test_anycast.py @@ -217,3 +217,30 @@ async def test_anycast_filter() -> None: assert (await receiver.receive()) == 12 assert (await receiver.receive()) == 15 + + +async def test_anycast_close_receiver() -> None: + """Ensure closing a receiver stops the receiver.""" + chan = Anycast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + await sender.send(1) + + assert (await receiver_1.receive()) == 1 + + receiver_1.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index c8a2e9cf..f995a922 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -232,6 +232,42 @@ async def test_broadcast_map() -> None: assert (await receiver.receive()) is True +async def test_broadcast_map_close_receiver() -> None: + """Ensure closing a map stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + plus_100_rx = receiver_1.map(lambda num: num + 100) + + await sender.send(1) + + assert (await plus_100_rx.receive()) == 101 + assert (await receiver_2.receive()) == 1 + + plus_100_rx.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await plus_100_rx.receive() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() + + async def test_broadcast_filter() -> None: """Ensure filter keeps only the messages that pass the filter.""" chan = Broadcast[int](name="input-chan") @@ -249,6 +285,43 @@ async def test_broadcast_filter() -> None: assert (await receiver.receive()) == 15 +async def test_broadcast_filter_close_receiver() -> None: + """Ensure closing a filter stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + gt_10_rx = receiver_1.filter(lambda num: num > 10) + + await sender.send(1) + assert (await receiver_2.receive()) == 1 + + await sender.send(100) + assert (await gt_10_rx.receive()) == 100 + assert (await receiver_2.receive()) == 100 + + gt_10_rx.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await gt_10_rx.receive() + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() + + async def test_broadcast_filter_type_guard() -> None: """Ensure filter type guard works.""" chan = Broadcast[int | str](name="input-chan") @@ -320,3 +393,35 @@ class Narrower(Actual): await sender.send(Narrower(10)) assert (await receiver.receive()).value == 10 + + +async def test_broadcast_close_receiver() -> None: + """Ensure closing a receiver stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + await sender.send(1) + + assert (await receiver_1.receive()) == 1 + assert (await receiver_2.receive()) == 1 + + receiver_1.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() diff --git a/tests/test_event.py b/tests/test_event.py index 950720d0..c9e061c5 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -57,3 +57,51 @@ async def wait_for_event() -> None: assert not event.is_set await event_task + + +async def test_event_close_receiver() -> None: + """Ensure that closing an event stops the receiver.""" + event = Event() + assert not event.is_set + assert not event.is_stopped + + is_ready = False + + async def wait_for_event() -> None: + nonlocal is_ready + await event.ready() + is_ready = True + + event_task = _asyncio.create_task(wait_for_event()) + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + + assert not is_ready + assert not event.is_set + assert not event.is_stopped + + event.set() + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + assert is_ready + assert event.is_set + assert not event.is_stopped + + event.consume() + assert not event.is_set + assert not event.is_stopped + assert event_task.done() + assert event_task.result() is None + assert not event_task.cancelled() + + event.close() + assert not event.is_set + assert event.is_stopped + + await event.ready() + with _pytest.raises(ReceiverStoppedError): + event.consume() + assert event.is_stopped + assert not event.is_set + + await event_task diff --git a/tests/test_file_watcher_integration.py b/tests/test_file_watcher_integration.py index ef1846e6..7d727aa2 100644 --- a/tests/test_file_watcher_integration.py +++ b/tests/test_file_watcher_integration.py @@ -150,3 +150,40 @@ async def test_file_watcher_exit_iterator(tmp_path: pathlib.Path) -> None: file_watcher.consume() assert number_of_writes == expected_number_of_writes + + +@pytest.mark.integration +async def test_file_watcher_close_receiver(tmp_path: pathlib.Path) -> None: + """Ensure closing the file watcher stops the receiver. + + Args: + tmp_path: A tmp directory to run the file watcher on. Created by pytest. + """ + filename = tmp_path / "test-file" + + number_of_writes = 0 + expected_number_of_writes = 3 + + file_watcher = FileWatcher( + paths=[str(tmp_path)], + force_polling=True, + polling_interval=timedelta(seconds=0.05), + ) + timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift()) + + async for selected in select(file_watcher, timer): + if selected_from(selected, timer): + filename.write_text(f"{selected.message}") + elif selected_from(selected, file_watcher): + number_of_writes += 1 + if number_of_writes == expected_number_of_writes: + file_watcher.close() + break + + ready = await file_watcher.ready() + assert ready is False + + with pytest.raises(ReceiverStoppedError): + file_watcher.consume() + + assert number_of_writes == expected_number_of_writes diff --git a/tests/test_merge_integration.py b/tests/test_merge_integration.py index 2e3eefcb..f5a61970 100644 --- a/tests/test_merge_integration.py +++ b/tests/test_merge_integration.py @@ -8,6 +8,8 @@ import pytest from frequenz.channels import Anycast, Sender, merge +from frequenz.channels._broadcast import Broadcast +from frequenz.channels._receiver import ReceiverStoppedError @pytest.mark.integration @@ -39,3 +41,44 @@ async def send(ch1: Sender[int], ch2: Sender[int]) -> None: # succession. assert set(results[idx : idx + 2]) == {ctr + 1, ctr + 101} assert results[-1] == 1000 + + +async def test_merge_close_receiver() -> None: + """Ensure merge() closes when a receiver is closed.""" + chan1 = Broadcast[int](name="chan1") + chan2 = Broadcast[int](name="chan2") + + async def send(ch1: Sender[int], ch2: Sender[int]) -> None: + for ctr in range(5): + await ch1.send(ctr + 1) + await ch2.send(ctr + 101) + await chan1.close() + await chan2.close() + + rx1 = chan1.new_receiver() + rx2 = chan2.new_receiver() + closing_merge = merge(rx1, rx2) + prx1 = chan1.new_receiver() + prx2 = chan2.new_receiver() + completing_merge = merge(prx1, prx2) + + senders = asyncio.create_task(send(chan1.new_sender(), chan2.new_sender())) + + results: list[int] = [] + async for item in closing_merge: + results.append(item) + if item == 3: + closing_merge.close() + await senders + assert set(results) == {1, 101, 2, 102, 3, 103} + + with pytest.raises(ReceiverStoppedError): + _ = await rx1.receive() + + with pytest.raises(ReceiverStoppedError): + _ = await rx2.receive() + + comp_results: set[int] = set() + async for item in completing_merge: + comp_results.add(item) + assert comp_results == {1, 101, 2, 102, 3, 103, 4, 104, 5, 105} diff --git a/tests/test_timer.py b/tests/test_timer.py index 1e1dc7ed..55295b29 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -13,6 +13,7 @@ import pytest from hypothesis import strategies as st +from frequenz.channels import ReceiverStoppedError from frequenz.channels.timer import ( SkipMissedAndDrift, SkipMissedAndResync, @@ -331,6 +332,21 @@ async def test_timer_construction_wrong_args() -> None: ) +async def test_timer_close_receiver() -> None: + """Test the autostart of a periodic timer.""" + event_loop = asyncio.get_running_loop() + + timer = Timer(timedelta(seconds=1.0), TriggerAllMissed()) + + drift = await timer.receive() + assert drift == pytest.approx(timedelta(seconds=0.0)) + assert event_loop.time() == pytest.approx(1.0) + + timer.close() + with pytest.raises(ReceiverStoppedError): + await timer.receive() + + async def test_timer_autostart() -> None: """Test the autostart of a periodic timer.""" event_loop = asyncio.get_running_loop() From 6e5d6356de2c9a657d6eb5ac826b0b42be4db7ca Mon Sep 17 00:00:00 2001 From: Sahas Subramanian Date: Tue, 7 Jan 2025 11:35:19 +0100 Subject: [PATCH 4/4] Update release notes Signed-off-by: Sahas Subramanian --- RELEASE_NOTES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 96a0240b..ffa1c322 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,7 @@ ## New Features - +- Added a `Receiver.close()` method for closing just a receiver. Also implemented it for all the `Receiver` implementations in this library. ## Bug Fixes