66
77from __future__ import annotations
88
9+ import bisect
910import logging
1011from abc import ABC , abstractmethod
12+ from dataclasses import dataclass
1113from datetime import datetime , timedelta , timezone
1214from typing import TYPE_CHECKING , Literal , TypedDict , overload
1315
@@ -390,13 +392,24 @@ def __eq__(self, other: object) -> bool:
390392 )
391393
392394
395+ @dataclass (frozen = True )
396+ class _CacheTier :
397+ """Configuration for a cache tier."""
398+
399+ interval : int # Cache every n'th key
400+ max_size : int | None # Maximum number of keys to cache in this tier (None = unlimited)
401+
402+
393403class _AccessoryKeyGenerator (KeyGenerator [KeyPair ]):
394404 """KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
395405
396- # cache enough keys for an entire week.
397- # every interval'th key is cached.
398- _CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
399- _CACHE_INTERVAL = 1 # cache every key
406+ # Define cache tiers: (interval, max_size)
407+ # Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
408+ # Tier 2: Cache every 672nd key (1 week), unlimited
409+ _CACHE_TIERS = (
410+ _CacheTier (interval = 4 , max_size = 672 ),
411+ _CacheTier (interval = 672 , max_size = None ),
412+ )
400413
401414 def __init__ (
402415 self ,
@@ -422,7 +435,9 @@ def __init__(
422435 self ._initial_sk = initial_sk
423436 self ._key_type = key_type
424437
425- self ._sk_cache : dict [int , bytes ] = {}
438+ # Multi-tier cache: dict + sorted indices per tier
439+ self ._sk_caches : list [dict [int , bytes ]] = [{} for _ in self ._CACHE_TIERS ]
440+ self ._cache_indices : list [list [int ]] = [[] for _ in self ._CACHE_TIERS ]
426441
427442 self ._iter_ind = 0
428443
@@ -441,36 +456,68 @@ def key_type(self) -> KeyPairType:
441456 """The type of key this generator produces."""
442457 return self ._key_type
443458
459+ def _find_best_cached_sk (self , ind : int ) -> tuple [int , bytes ]:
460+ """Find the largest cached index smaller than ind across all tiers."""
461+ best_ind = 0
462+ best_sk = self ._initial_sk
463+
464+ for indices , cache in zip (self ._cache_indices , self ._sk_caches , strict = True ):
465+ if not indices :
466+ continue
467+
468+ # Use bisect to find the largest index < ind in O(log n)
469+ pos = bisect .bisect_left (indices , ind )
470+ if pos == 0 : # No cached index less than ind
471+ continue
472+
473+ cached_ind = indices [pos - 1 ]
474+ if cached_ind > best_ind :
475+ best_ind = cached_ind
476+ best_sk = cache [cached_ind ]
477+
478+ return best_ind , best_sk
479+
480+ def _update_caches (self , ind : int , sk : bytes ) -> None :
481+ """Update all applicable cache tiers with the computed key."""
482+ for tier_idx , tier in enumerate (self ._CACHE_TIERS ):
483+ if ind % tier .interval != 0 :
484+ continue
485+
486+ cache = self ._sk_caches [tier_idx ]
487+ indices = self ._cache_indices [tier_idx ]
488+
489+ # Add to cache if not already present
490+ if ind in cache :
491+ continue
492+ cache [ind ] = sk
493+ bisect .insort (indices , ind )
494+
495+ # Evict if cache exceeds size limit
496+ if tier .max_size is not None and len (cache ) > tier .max_size :
497+ # If adding a historical key, evict smallest index
498+ # If adding a future key, evict largest
499+ evict_ind = indices .pop (0 if indices and ind > indices [0 ] else - 1 )
500+
501+ del cache [evict_ind ]
502+
444503 def _get_sk (self , ind : int ) -> bytes :
445504 if ind < 0 :
446505 msg = "The key index must be non-negative"
447506 raise ValueError (msg )
448507
449- # retrieve from cache
450- cached_sk = self ._sk_cache .get (ind )
451- if cached_sk is not None :
452- return cached_sk
508+ # Check all caches for exact match
509+ for cache in self ._sk_caches :
510+ cached_sk = cache .get (ind )
511+ if cached_sk is not None :
512+ return cached_sk
453513
454- # not in cache: find largest cached index smaller than ind (if exists)
455- start_ind : int = 0
456- cur_sk : bytes = self ._initial_sk
457- for cached_ind in self ._sk_cache :
458- if cached_ind < ind and cached_ind > start_ind :
459- start_ind = cached_ind
460- cur_sk = self ._sk_cache [cached_ind ]
514+ # Find best starting point across all tiers
515+ start_ind , cur_sk = self ._find_best_cached_sk (ind )
461516
462- # compute and update cache
517+ # Compute from best cached position to target
463518 for cur_ind in range (start_ind + 1 , ind + 1 ):
464519 cur_sk = crypto .x963_kdf (cur_sk , b"update" , 32 )
465-
466- # insert intermediate result into cache and evict oldest entry if necessary
467- if cur_ind % self ._CACHE_INTERVAL == 0 :
468- self ._sk_cache [cur_ind ] = cur_sk
469-
470- if len (self ._sk_cache ) > self ._CACHE_SIZE :
471- # evict oldest entry
472- oldest_ind = min (self ._sk_cache .keys ())
473- del self ._sk_cache [oldest_ind ]
520+ self ._update_caches (cur_ind , cur_sk )
474521
475522 return cur_sk
476523
0 commit comments