diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 96a0240b..52636b7c 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,7 @@ ## New Features - +- `LatestValueCache` now takes an optional `key` function, which returns the key for each incoming message, and the latest value for each key is cached and can be retrieved separately. ## Bug Fixes diff --git a/src/frequenz/channels/_latest_value_cache.py b/src/frequenz/channels/_latest_value_cache.py index a9af9436..fa04355b 100644 --- a/src/frequenz/channels/_latest_value_cache.py +++ b/src/frequenz/channels/_latest_value_cache.py @@ -8,12 +8,20 @@ [LatestValueCache][frequenz.channels.LatestValueCache] takes a [Receiver][frequenz.channels.Receiver] as an argument and stores the latest -value received by that receiver. As soon as a value is received, its +value received by that receiver. It also takes an optional `key` function +that allows you to group the values by a specific key. If the `key` is +provided, the cache will store the latest value for each key separately, +otherwise it will store only the latest value received overall. + +As soon as a value is received, its [`has_value`][frequenz.channels.LatestValueCache.has_value] method returns `True`, and its [`get`][frequenz.channels.LatestValueCache.get] method returns the latest value received. The `get` method will raise an exception if called before any messages have been received from the receiver. +Both `has_value` and `get` methods can take an optional `key` argument to +check or retrieve the latest value for that specific key. + Example: ```python from frequenz.channels import Broadcast, LatestValueCache @@ -32,31 +40,84 @@ ``` """ +from __future__ import annotations + import asyncio import typing from ._receiver import Receiver T_co = typing.TypeVar("T_co", covariant=True) +HashableT = typing.TypeVar("HashableT", bound=typing.Hashable) -class _Sentinel: +class Sentinel: """A sentinel to denote that no value has been received yet.""" + def __init__(self, desc: str) -> None: + """Initialize the sentinel.""" + self._desc = desc + def __str__(self) -> str: """Return a string representation of this sentinel.""" - return "" + return f"" -class LatestValueCache(typing.Generic[T_co]): +NO_KEY: typing.Final[Sentinel] = Sentinel("no key provided") +NO_KEY_FUNCTION: typing.Final[Sentinel] = Sentinel("no key function provided") +NO_VALUE_RECEIVED: typing.Final[Sentinel] = Sentinel("no value received yet") + + +class LatestValueCache(typing.Generic[T_co, HashableT]): """A cache that stores the latest value in a receiver. It provides a way to look up the latest value in a stream without any delay, as long as there has been one value received. """ + @typing.overload + def __init__( + self: LatestValueCache[T_co, Sentinel], + receiver: Receiver[T_co], + *, + unique_id: str | None = None, + key: Sentinel = NO_KEY_FUNCTION, + ) -> None: + """Create a new cache that does not use keys. + + Args: + receiver: The receiver to cache. + unique_id: A string to help uniquely identify this instance. If not + provided, a unique identifier will be generated from the object's + [`id()`][id]. It is used mostly for debugging purposes. + key: This parameter is ignored when set to `None`. + """ + + @typing.overload def __init__( - self, receiver: Receiver[T_co], *, unique_id: str | None = None + self: LatestValueCache[T_co, HashableT], + receiver: Receiver[T_co], + *, + unique_id: str | None = None, + key: typing.Callable[[T_co], HashableT], + ) -> None: + """Create a new cache that uses keys. + + Args: + receiver: The receiver to cache. + unique_id: A string to help uniquely identify this instance. If not + provided, a unique identifier will be generated from the object's + [`id()`][id]. It is used mostly for debugging purposes. + key: A function that takes a value and returns a key to group the values by. + If provided, the cache will store the latest value for each key separately. + """ + + def __init__( + self, + receiver: Receiver[T_co], + *, + unique_id: str | None = None, + key: typing.Callable[[T_co], typing.Any] | Sentinel = NO_KEY_FUNCTION, ) -> None: """Create a new cache. @@ -65,10 +126,16 @@ def __init__( unique_id: A string to help uniquely identify this instance. If not provided, a unique identifier will be generated from the object's [`id()`][id]. It is used mostly for debugging purposes. + key: An optional function that takes a value and returns a key to group the + values by. If provided, the cache will store the latest value for each + key separately. If not provided, it will store only the latest value + received overall. """ self._receiver = receiver + self._key: typing.Callable[[T_co], HashableT] | Sentinel = key self._unique_id: str = hex(id(self)) if unique_id is None else unique_id - self._latest_value: T_co | _Sentinel = _Sentinel() + self._latest_value: T_co | Sentinel = NO_VALUE_RECEIVED + self._latest_value_by_key: dict[HashableT, T_co] = {} self._task = asyncio.create_task( self._run(), name=f"LatestValueCache«{self._unique_id}»" ) @@ -78,34 +145,53 @@ def unique_id(self) -> str: """The unique identifier of this instance.""" return self._unique_id - def get(self) -> T_co: + def get(self, key: HashableT | Sentinel = NO_KEY) -> T_co: """Return the latest value that has been received. This raises a `ValueError` if no value has been received yet. Use `has_value` to check whether a value has been received yet, before trying to access the value, to avoid the exception. + Args: + key: An optional key to retrieve the latest value for that key. If not + provided, it retrieves the latest value received overall. + Returns: The latest value that has been received. Raises: ValueError: If no value has been received yet. """ - if isinstance(self._latest_value, _Sentinel): + if not isinstance(key, Sentinel): + if key not in self._latest_value_by_key: + raise ValueError(f"No value received for key: {key!r}") + return self._latest_value_by_key[key] + + if isinstance(self._latest_value, Sentinel): raise ValueError("No value has been received yet.") return self._latest_value - def has_value(self) -> bool: + def has_value(self, key: HashableT | Sentinel = NO_KEY) -> bool: """Check whether a value has been received yet. + If `key` is provided, it checks whether a value has been received for that key. + + Args: + key: An optional key to check if a value has been received for that key. + Returns: `True` if a value has been received, `False` otherwise. """ - return not isinstance(self._latest_value, _Sentinel) + if not isinstance(key, Sentinel): + return key in self._latest_value_by_key + return not isinstance(self._latest_value, Sentinel) async def _run(self) -> None: async for value in self._receiver: self._latest_value = value + if not isinstance(self._key, Sentinel): + key = self._key(value) + self._latest_value_by_key[key] = value async def stop(self) -> None: """Stop the cache.""" diff --git a/tests/test_latest_value_cache_integration.py b/tests/test_latest_value_cache_integration.py index 0c39b21e..97020c42 100644 --- a/tests/test_latest_value_cache_integration.py +++ b/tests/test_latest_value_cache_integration.py @@ -43,3 +43,52 @@ async def test_latest_value_cache() -> None: await asyncio.sleep(0) assert cache.get() == 19 + + +@pytest.mark.integration +async def test_latest_value_cache_key() -> None: + """Ensure LatestValueCache works with keys.""" + channel = Broadcast[tuple[int, str]](name="lvc_test") + + cache = LatestValueCache(channel.new_receiver(), key=lambda x: x[0]) + sender = channel.new_sender() + + assert not cache.has_value() + with pytest.raises(ValueError, match="No value has been received yet."): + cache.get() + with pytest.raises(ValueError, match="No value received for key: 0"): + cache.get(0) + + await sender.send((5, "a")) + await sender.send((6, "b")) + await sender.send((5, "c")) + await asyncio.sleep(0) + + assert cache.has_value() + assert cache.has_value(5) + assert cache.has_value(6) + assert not cache.has_value(7) + + assert cache.get() == (5, "c") + assert cache.get(5) == (5, "c") + assert cache.get(6) == (6, "b") + + with pytest.raises(ValueError, match="No value received for key: 7"): + cache.get(7) + + await sender.send((12, "d")) + await asyncio.sleep(0) + + assert cache.get() == (12, "d") + assert cache.get() == (12, "d") + assert cache.get(12) == (12, "d") + assert cache.get(12) == (12, "d") + assert cache.get(5) == (5, "c") + assert cache.get(6) == (6, "b") + + await sender.send((6, "e")) + await sender.send((6, "f")) + await sender.send((6, "g")) + await asyncio.sleep(0) + + assert cache.get(6) == (6, "g")