Skip to content

Commit b3d4a71

Browse files
authored
Merge pull request #216 from malmeloo/feat/better-key-cache
fix: more efficient key caching system
2 parents 485b984 + c9897ea commit b3d4a71

File tree

1 file changed

+73
-26
lines changed

1 file changed

+73
-26
lines changed

findmy/accessory.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
from __future__ import annotations
88

9+
import bisect
910
import logging
1011
from abc import ABC, abstractmethod
12+
from dataclasses import dataclass
1113
from datetime import datetime, timedelta, timezone
1214
from typing import TYPE_CHECKING, Literal, TypedDict, overload
1315

@@ -390,13 +392,24 @@ def __eq__(self, other: object) -> bool:
390392
)
391393

392394

395+
@dataclass(frozen=True)
396+
class _CacheTier:
397+
"""Configuration for a cache tier."""
398+
399+
interval: int # Cache every n'th key
400+
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)
401+
402+
393403
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
394404
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
395405

396-
# cache enough keys for an entire week.
397-
# every interval'th key is cached.
398-
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
399-
_CACHE_INTERVAL = 1 # cache every key
406+
# Define cache tiers: (interval, max_size)
407+
# Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
408+
# Tier 2: Cache every 672nd key (1 week), unlimited
409+
_CACHE_TIERS = (
410+
_CacheTier(interval=4, max_size=672),
411+
_CacheTier(interval=672, max_size=None),
412+
)
400413

401414
def __init__(
402415
self,
@@ -422,7 +435,9 @@ def __init__(
422435
self._initial_sk = initial_sk
423436
self._key_type = key_type
424437

425-
self._sk_cache: dict[int, bytes] = {}
438+
# Multi-tier cache: dict + sorted indices per tier
439+
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
440+
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]
426441

427442
self._iter_ind = 0
428443

@@ -441,36 +456,68 @@ def key_type(self) -> KeyPairType:
441456
"""The type of key this generator produces."""
442457
return self._key_type
443458

459+
def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
460+
"""Find the largest cached index smaller than ind across all tiers."""
461+
best_ind = 0
462+
best_sk = self._initial_sk
463+
464+
for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
465+
if not indices:
466+
continue
467+
468+
# Use bisect to find the largest index < ind in O(log n)
469+
pos = bisect.bisect_left(indices, ind)
470+
if pos == 0: # No cached index less than ind
471+
continue
472+
473+
cached_ind = indices[pos - 1]
474+
if cached_ind > best_ind:
475+
best_ind = cached_ind
476+
best_sk = cache[cached_ind]
477+
478+
return best_ind, best_sk
479+
480+
def _update_caches(self, ind: int, sk: bytes) -> None:
481+
"""Update all applicable cache tiers with the computed key."""
482+
for tier_idx, tier in enumerate(self._CACHE_TIERS):
483+
if ind % tier.interval != 0:
484+
continue
485+
486+
cache = self._sk_caches[tier_idx]
487+
indices = self._cache_indices[tier_idx]
488+
489+
# Add to cache if not already present
490+
if ind in cache:
491+
continue
492+
cache[ind] = sk
493+
bisect.insort(indices, ind)
494+
495+
# Evict if cache exceeds size limit
496+
if tier.max_size is not None and len(cache) > tier.max_size:
497+
# If adding a historical key, evict smallest index
498+
# If adding a future key, evict largest
499+
evict_ind = indices.pop(0 if indices and ind > indices[0] else -1)
500+
501+
del cache[evict_ind]
502+
444503
def _get_sk(self, ind: int) -> bytes:
445504
if ind < 0:
446505
msg = "The key index must be non-negative"
447506
raise ValueError(msg)
448507

449-
# retrieve from cache
450-
cached_sk = self._sk_cache.get(ind)
451-
if cached_sk is not None:
452-
return cached_sk
508+
# Check all caches for exact match
509+
for cache in self._sk_caches:
510+
cached_sk = cache.get(ind)
511+
if cached_sk is not None:
512+
return cached_sk
453513

454-
# not in cache: find largest cached index smaller than ind (if exists)
455-
start_ind: int = 0
456-
cur_sk: bytes = self._initial_sk
457-
for cached_ind in self._sk_cache:
458-
if cached_ind < ind and cached_ind > start_ind:
459-
start_ind = cached_ind
460-
cur_sk = self._sk_cache[cached_ind]
514+
# Find best starting point across all tiers
515+
start_ind, cur_sk = self._find_best_cached_sk(ind)
461516

462-
# compute and update cache
517+
# Compute from best cached position to target
463518
for cur_ind in range(start_ind + 1, ind + 1):
464519
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
465-
466-
# insert intermediate result into cache and evict oldest entry if necessary
467-
if cur_ind % self._CACHE_INTERVAL == 0:
468-
self._sk_cache[cur_ind] = cur_sk
469-
470-
if len(self._sk_cache) > self._CACHE_SIZE:
471-
# evict oldest entry
472-
oldest_ind = min(self._sk_cache.keys())
473-
del self._sk_cache[oldest_ind]
520+
self._update_caches(cur_ind, cur_sk)
474521

475522
return cur_sk
476523

0 commit comments

Comments
 (0)