diff --git a/findmy/accessory.py b/findmy/accessory.py index 4e79da4..c58eda9 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -6,8 +6,10 @@ from __future__ import annotations +import bisect import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Literal, TypedDict, overload @@ -390,13 +392,24 @@ def __eq__(self, other: object) -> bool: ) +@dataclass(frozen=True) +class _CacheTier: + """Configuration for a cache tier.""" + + interval: int # Cache every n'th key + max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited) + + class _AccessoryKeyGenerator(KeyGenerator[KeyPair]): """KeyPair generator. Uses the same algorithm internally as FindMy accessories do.""" - # cache enough keys for an entire week. - # every interval'th key is cached. - _CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour - _CACHE_INTERVAL = 1 # cache every key + # Define cache tiers: (interval, max_size) + # Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals) + # Tier 2: Cache every 672nd key (1 week), unlimited + _CACHE_TIERS = ( + _CacheTier(interval=4, max_size=672), + _CacheTier(interval=672, max_size=None), + ) def __init__( self, @@ -422,7 +435,9 @@ def __init__( self._initial_sk = initial_sk self._key_type = key_type - self._sk_cache: dict[int, bytes] = {} + # Multi-tier cache: dict + sorted indices per tier + self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS] + self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS] self._iter_ind = 0 @@ -441,36 +456,68 @@ def key_type(self) -> KeyPairType: """The type of key this generator produces.""" return self._key_type + def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]: + """Find the largest cached index smaller than ind across all tiers.""" + best_ind = 0 + best_sk = self._initial_sk + + for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True): + if not indices: + continue + + # Use bisect to find the largest index < ind in O(log n) + pos = bisect.bisect_left(indices, ind) + if pos == 0: # No cached index less than ind + continue + + cached_ind = indices[pos - 1] + if cached_ind > best_ind: + best_ind = cached_ind + best_sk = cache[cached_ind] + + return best_ind, best_sk + + def _update_caches(self, ind: int, sk: bytes) -> None: + """Update all applicable cache tiers with the computed key.""" + for tier_idx, tier in enumerate(self._CACHE_TIERS): + if ind % tier.interval != 0: + continue + + cache = self._sk_caches[tier_idx] + indices = self._cache_indices[tier_idx] + + # Add to cache if not already present + if ind in cache: + continue + cache[ind] = sk + bisect.insort(indices, ind) + + # Evict if cache exceeds size limit + if tier.max_size is not None and len(cache) > tier.max_size: + # If adding a historical key, evict smallest index + # If adding a future key, evict largest + evict_ind = indices.pop(0 if indices and ind > indices[0] else -1) + + del cache[evict_ind] + def _get_sk(self, ind: int) -> bytes: if ind < 0: msg = "The key index must be non-negative" raise ValueError(msg) - # retrieve from cache - cached_sk = self._sk_cache.get(ind) - if cached_sk is not None: - return cached_sk + # Check all caches for exact match + for cache in self._sk_caches: + cached_sk = cache.get(ind) + if cached_sk is not None: + return cached_sk - # not in cache: find largest cached index smaller than ind (if exists) - start_ind: int = 0 - cur_sk: bytes = self._initial_sk - for cached_ind in self._sk_cache: - if cached_ind < ind and cached_ind > start_ind: - start_ind = cached_ind - cur_sk = self._sk_cache[cached_ind] + # Find best starting point across all tiers + start_ind, cur_sk = self._find_best_cached_sk(ind) - # compute and update cache + # Compute from best cached position to target for cur_ind in range(start_ind + 1, ind + 1): cur_sk = crypto.x963_kdf(cur_sk, b"update", 32) - - # insert intermediate result into cache and evict oldest entry if necessary - if cur_ind % self._CACHE_INTERVAL == 0: - self._sk_cache[cur_ind] = cur_sk - - if len(self._sk_cache) > self._CACHE_SIZE: - # evict oldest entry - oldest_ind = min(self._sk_cache.keys()) - del self._sk_cache[oldest_ind] + self._update_caches(cur_ind, cur_sk) return cur_sk