Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions findmy/accessory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from __future__ import annotations

import bisect
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Literal, TypedDict, overload

Expand Down Expand Up @@ -390,13 +392,24 @@ def __eq__(self, other: object) -> bool:
)


@dataclass(frozen=True)
class _CacheTier:
"""Configuration for a cache tier."""

interval: int # Cache every n'th key
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)


class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""

# cache enough keys for an entire week.
# every interval'th key is cached.
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
_CACHE_INTERVAL = 1 # cache every key
# Define cache tiers: (interval, max_size)
# Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
# Tier 2: Cache every 672nd key (1 week), unlimited
_CACHE_TIERS = (
_CacheTier(interval=4, max_size=672),
_CacheTier(interval=672, max_size=None),
)

def __init__(
self,
Expand All @@ -422,7 +435,9 @@ def __init__(
self._initial_sk = initial_sk
self._key_type = key_type

self._sk_cache: dict[int, bytes] = {}
# Multi-tier cache: dict + sorted indices per tier
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]

self._iter_ind = 0

Expand All @@ -441,36 +456,68 @@ def key_type(self) -> KeyPairType:
"""The type of key this generator produces."""
return self._key_type

def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
"""Find the largest cached index smaller than ind across all tiers."""
best_ind = 0
best_sk = self._initial_sk

for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
if not indices:
continue

# Use bisect to find the largest index < ind in O(log n)
pos = bisect.bisect_left(indices, ind)
if pos == 0: # No cached index less than ind
continue

cached_ind = indices[pos - 1]
if cached_ind > best_ind:
best_ind = cached_ind
best_sk = cache[cached_ind]

return best_ind, best_sk

def _update_caches(self, ind: int, sk: bytes) -> None:
"""Update all applicable cache tiers with the computed key."""
for tier_idx, tier in enumerate(self._CACHE_TIERS):
if ind % tier.interval != 0:
continue

cache = self._sk_caches[tier_idx]
indices = self._cache_indices[tier_idx]

# Add to cache if not already present
if ind in cache:
continue
cache[ind] = sk
bisect.insort(indices, ind)

# Evict if cache exceeds size limit
if tier.max_size is not None and len(cache) > tier.max_size:
# If adding a historical key, evict smallest index
# If adding a future key, evict largest
evict_ind = indices.pop(0 if indices and ind > indices[0] else -1)

del cache[evict_ind]

def _get_sk(self, ind: int) -> bytes:
if ind < 0:
msg = "The key index must be non-negative"
raise ValueError(msg)

# retrieve from cache
cached_sk = self._sk_cache.get(ind)
if cached_sk is not None:
return cached_sk
# Check all caches for exact match
for cache in self._sk_caches:
cached_sk = cache.get(ind)
if cached_sk is not None:
return cached_sk

# not in cache: find largest cached index smaller than ind (if exists)
start_ind: int = 0
cur_sk: bytes = self._initial_sk
for cached_ind in self._sk_cache:
if cached_ind < ind and cached_ind > start_ind:
start_ind = cached_ind
cur_sk = self._sk_cache[cached_ind]
# Find best starting point across all tiers
start_ind, cur_sk = self._find_best_cached_sk(ind)

# compute and update cache
# Compute from best cached position to target
for cur_ind in range(start_ind + 1, ind + 1):
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)

# insert intermediate result into cache and evict oldest entry if necessary
if cur_ind % self._CACHE_INTERVAL == 0:
self._sk_cache[cur_ind] = cur_sk

if len(self._sk_cache) > self._CACHE_SIZE:
# evict oldest entry
oldest_ind = min(self._sk_cache.keys())
del self._sk_cache[oldest_ind]
self._update_caches(cur_ind, cur_sk)

return cur_sk

Expand Down