diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index ef8bce74..d683396d 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -11,6 +11,9 @@ ## New Features - An optional `tick_at_start` parameter has been added to `Timer`. When `True`, the timer will trigger immediately after starting, and then wait for the interval before triggering again. +- Add `Receiver.fork` method to create independent clones of the receiver. + - Useful for scenarios where multiple consumers need to process the same stream of messages. Each forked receiver. + - Each forked receiver maintains its own independent message queue ## Bug Fixes diff --git a/pyproject.toml b/pyproject.toml index e7e768ff..3c940a93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] requires-python = ">= 3.11, < 4" dependencies = [ - "typing-extensions >= 4.5.0, < 5", + "typing-extensions >= 4.11.0, < 5", "watchfiles >= 0.15.0, < 1.1.0", ] dynamic = ["version"] @@ -39,7 +39,7 @@ email = "floss@frequenz.com" dev-flake8 = [ "flake8 == 7.1.1", "flake8-docstrings == 1.7.0", - "flake8-pyproject == 1.2.3", # For reading the flake8 config from pyproject.toml + "flake8-pyproject == 1.2.3", # For reading the flake8 config from pyproject.toml "pydoclint == 0.6.0", "pydocstyle == 6.3.0", ] diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index a3d2a846..60d7b54b 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -463,6 +463,25 @@ def close(self) -> None: """ self._closed = True + @override + def fork(self, *, name: str | None = None) -> "Receiver[_T]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. This is ignored as Anycast + receivers don't have names. + + Returns: + A new receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If the receiver is closed. + """ + if self._closed: + raise ReceiverStoppedError(self) + + return self._channel.new_receiver() + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index cd31a9f4..9625ed77 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -508,6 +508,33 @@ def close(self) -> None: hash(self), None ) + @override + def fork(self, *, name: str | None = None) -> "Receiver[_T]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. If None, a new name will be + generated based on the receiver's id. + + Returns: + A new receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If the receiver is closed. + """ + if self._closed: + raise ReceiverStoppedError(self) + + limit = self._q.maxlen + assert limit is not None + + fork_name = name if name is not None else None + + # Create a new receiver with the same configuration + return self._channel.new_receiver( + name=fork_name, limit=limit, warn_on_overflow=self._warn_on_overflow + ) + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_merge.py b/src/frequenz/channels/_merge.py index b1f38857..bb09bdc0 100644 --- a/src/frequenz/channels/_merge.py +++ b/src/frequenz/channels/_merge.py @@ -206,6 +206,35 @@ def close(self) -> None: for recv in self._receivers.values(): recv.close() + @override + def fork(self, *, name: str | None = None) -> "Merger[ReceiverMessageT_co]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. If None, the same naming + approach as the original merger will be used. + + Returns: + A new receiver that is a clone of this receiver. + """ + # Fork all the underlying not stopped receivers + + forked_receivers: list[Receiver[ReceiverMessageT_co]] = [] + for recv_name, recv in self._receivers.items(): + # Don't fork stopped receivers + try: + forked = recv.fork(name=recv_name) + except ReceiverStoppedError: + continue + else: + forked_receivers.append(forked) + + # Use the provided name or the same approach as original + fork_name = name if name is not None else self._name + + # Create a new merger with the forked receivers + return Merger(*forked_receivers, name=fork_name) + def __str__(self) -> str: """Return a string representation of this receiver.""" if len(self._receivers) > 3: diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 57876d7b..e34cc4b2 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -239,6 +239,20 @@ def close(self) -> None: """ raise NotImplementedError("close() must be implemented by subclasses") + @abstractmethod + def fork(self, *, name: str | None = None) -> "Receiver[ReceiverMessageT_co]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. + + Returns: + A new receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If the receiver is stopped. + """ + def __aiter__(self) -> Self: """Get an async iterator over the received messages. @@ -496,6 +510,24 @@ def close(self) -> None: """ self._receiver.close() + @override + def fork( + self, *, name: str | None = None + ) -> "_Mapper[ReceiverMessageT_co, MappedMessageT_co]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. This is ignored since mapper + receivers don't have names. + + Returns: + A new receiver that is a clone of this receiver. + """ + return _Mapper( + receiver=self._receiver.fork(name=name), + mapping_function=self._mapping_function, + ) + def __str__(self) -> str: """Return a string representation of the mapper.""" return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}" @@ -573,7 +605,7 @@ def consume(self) -> ReceiverMessageT_co: The next message that was received. Raises: - ReceiverStoppedError: If the receiver stopped producing messages. + ReceiverStoppedError: If the receiver is stopped. """ if self._recv_closed: raise ReceiverStoppedError(self) @@ -595,6 +627,22 @@ def close(self) -> None: """ self._receiver.close() + @override + def fork(self, *, name: str | None = None) -> "_Filter[ReceiverMessageT_co]": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. This is ignored since filter + receivers don't have names. + + Returns: + A new receiver that is a clone of this receiver. + """ + return _Filter( + receiver=self._receiver.fork(name=name), + filter_function=self._filter_function, + ) + def __str__(self) -> str: """Return a string representation of the filter.""" return f"{type(self).__name__}:{self._receiver}:{self._filter_function}" diff --git a/src/frequenz/channels/event.py b/src/frequenz/channels/event.py index 5d1bd425..fa665f86 100644 --- a/src/frequenz/channels/event.py +++ b/src/frequenz/channels/event.py @@ -177,6 +177,25 @@ def close(self) -> None: """Close this receiver.""" self.stop() + @override + def fork(self, *, name: str | None = None) -> "Event": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. If None, an id-based name + will be used. + + Returns: + A new Event receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If this receiver is stopped. + """ + if self._is_stopped: + raise ReceiverStoppedError(self) + + return Event(name=name) + def __str__(self) -> str: """Return a string representation of this event.""" return f"{type(self).__name__}({self._name!r})" diff --git a/src/frequenz/channels/file_watcher.py b/src/frequenz/channels/file_watcher.py index e9ff4ca4..ec60ca4f 100644 --- a/src/frequenz/channels/file_watcher.py +++ b/src/frequenz/channels/file_watcher.py @@ -56,7 +56,7 @@ class Event: """The path where the change was observed.""" -class FileWatcher(Receiver[Event]): +class FileWatcher(Receiver[Event]): # pylint: disable=too-many-instance-attributes """A receiver that watches for file events. # Usage @@ -147,7 +147,8 @@ def __init__( polling is enabled. """ self.event_types: frozenset[EventType] = frozenset(event_types) - """The types of events to watch for.""" + self._force_polling: bool = force_polling + self._polling_interval: timedelta = polling_interval self._stop_event: asyncio.Event = asyncio.Event() self._paths: list[pathlib.Path] = [ @@ -250,3 +251,29 @@ def __str__(self) -> str: def __repr__(self) -> str: """Return a string representation of this receiver.""" return f"{type(self).__name__}({self._paths!r}, {self.event_types!r})" + + @override + def fork(self, *, name: str | None = None) -> "FileWatcher": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. This is ignored since FileWatcher + receivers don't have names. + + Returns: + A new receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If this receiver is stopped. + """ + if self._awatch_stopped_exc is not None: + raise ReceiverStoppedError(self) + + return FileWatcher( + # list[pathlib.Path] is the correct type ( expected list[pathlib.Path | str] ) + # but mypy doesn't know that + paths=self._paths, # type: ignore + event_types=self.event_types, + force_polling=self._force_polling, + polling_interval=self._polling_interval, + ) diff --git a/src/frequenz/channels/timer.py b/src/frequenz/channels/timer.py index 73a76534..8f2ecbf8 100644 --- a/src/frequenz/channels/timer.py +++ b/src/frequenz/channels/timer.py @@ -789,3 +789,36 @@ def __repr__(self) -> str: f"{type(self).__name__}<{self.interval=}, {self.missed_tick_policy=}, " f"{self.loop=}, {self.is_running=}>" ) + + @override + def fork(self, *, name: str | None = None) -> "Timer": + """Create a new receiver that is a clone of this receiver. + + Args: + name: An optional name for the new receiver. This is ignored since Timer + receivers don't have names. + + Returns: + A new receiver that is a clone of this receiver. + + Raises: + ReceiverStoppedError: If the timer was stopped via `stop()`. + """ + if self._stopped: + raise ReceiverStoppedError(self) + + # Create a new timer with the same configuration + new_timer = Timer( + self.interval, + self.missed_tick_policy, + auto_start=self.is_running, + loop=self.loop, + ) + + # If the original timer has a next tick time set, sync the new timer + if self._next_tick_time is not None: + new_timer._next_tick_time = ( # pylint: disable=protected-access + self._next_tick_time + ) + + return new_timer diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index f995a922..666eb446 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -5,6 +5,7 @@ import asyncio +from contextlib import closing from dataclasses import dataclass from typing import TypeGuard, assert_never @@ -425,3 +426,30 @@ async def test_broadcast_close_receiver() -> None: with pytest.raises(ReceiverStoppedError): _ = await receiver_2.receive() + + +async def test_receiver_fork() -> None: + """Ensure that a receiver can be forked.""" + chan = Broadcast[int](name="input-chan") + + with ( + closing(Broadcast[int](name="input-chan")) as chan, + closing(chan.new_receiver()) as receiver, + closing(receiver.fork()) as forked_receiver, + ): + sender = chan.new_sender() + await sender.send(1) + + assert (await receiver.receive()) == 1 + assert (await forked_receiver.receive()) == 1 + + +async def test_fork_stopped_receiver() -> None: + """Ensure that a receiver can be forked.""" + chan = Broadcast[int](name="input-chan") + + receiver = chan.new_receiver() + receiver.close() + + with pytest.raises(ReceiverStoppedError): + _ = receiver.fork()