diff --git a/chia/_tests/core/util/test_lru_cache.py b/chia/_tests/core/util/test_lru_cache.py index 905529cc23b4..fb68f585ce6d 100644 --- a/chia/_tests/core/util/test_lru_cache.py +++ b/chia/_tests/core/util/test_lru_cache.py @@ -2,6 +2,8 @@ import unittest +import pytest + from chia.util.lru_cache import LRUCache @@ -54,3 +56,17 @@ def test_lru_cache(self): assert len(cache.cache) == 5 assert cache.get(b"0") is None assert cache.get(b"1") == 1 + + +@pytest.mark.parametrize(argnames="capacity", argvalues=[-10, -1, 0]) +def test_with_zero_capacity(capacity: int) -> None: + cache: LRUCache[bytes, int] = LRUCache(capacity=capacity) + cache.put(b"0", 1) + assert cache.get(b"0") is None + assert len(cache.cache) == 0 + + +@pytest.mark.parametrize(argnames="capacity", argvalues=[-10, -1, 0, 1, 5, 10]) +def test_get_capacity(capacity: int) -> None: + cache: LRUCache[object, object] = LRUCache(capacity=capacity) + assert cache.get_capacity() == capacity diff --git a/chia/util/lru_cache.py b/chia/util/lru_cache.py index 143df0dcfce2..a5032366c6b3 100644 --- a/chia/util/lru_cache.py +++ b/chia/util/lru_cache.py @@ -22,10 +22,14 @@ def get(self, key: K) -> Optional[V]: return self.cache[key] def put(self, key: K, value: V) -> None: - self.cache[key] = value - self.cache.move_to_end(key) - if len(self.cache) > self.capacity: - self.cache.popitem(last=False) + if self.capacity > 0: + self.cache[key] = value + self.cache.move_to_end(key) + if len(self.cache) > self.capacity: + self.cache.popitem(last=False) def remove(self, key: K) -> None: self.cache.pop(key) + + def get_capacity(self) -> int: + return self.capacity