Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/concept_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
)
from .index import get_index
from .models import Facet, QueryModel, ResolvedMention
from .pipeline import run_pipeline
from .pipeline import pipeline_cache, run_pipeline
from .rate_limit import RateLimiter
from .resolve_agent import resolve_cache
from .store import DuckDBStore

# Structured JSON logging to stdout (picked up by CloudWatch via App Runner)
Expand Down Expand Up @@ -367,6 +368,8 @@ async def health() -> dict:
index = get_index()
return {
"indexStats": index.stats,
"pipelineCache": pipeline_cache.stats,
"resolveCache": resolve_cache.stats,
"status": "ok",
}

Expand Down
145 changes: 145 additions & 0 deletions backend/concept_search/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Generic async LRU cache with TTL and in-flight deduplication."""

from __future__ import annotations

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Generic, TypeVar

# Module-level logger; callers configure handlers/levels.
logger = logging.getLogger(__name__)

# Type variables for cache keys and values.
K = TypeVar("K")
V = TypeVar("V")

# Registry of all cache instances for bulk clear
_all_caches: list[LRUCache] = [] # type: ignore[type-arg]


@dataclass
class _CacheEntry(Generic[V]):
    """A cached value with creation timestamp.

    The timestamp comes from ``time.monotonic()`` (see the call sites in
    ``LRUCache``), so TTL checks are immune to wall-clock adjustments.
    """

    # Monotonic-clock time at which the value was cached.
    created: float
    # The cached payload.
    value: V


@dataclass
class LRUCache(Generic[K, V]):
    """Async LRU cache with TTL and in-flight deduplication.

    - Entries expire after ``ttl_seconds``.
    - When ``max_size`` is reached the oldest entry is evicted.
    - Concurrent calls for the same key share a single computation.

    All instances are registered for bulk clearing via ``clear_all()``.
    """

    # Human-readable cache name, used in log lines and clear_all() results.
    name: str
    # Count of lookups served from the cache.
    hits: int = 0
    # Maximum number of entries held before the oldest is evicted.
    max_size: int = 10_000
    # Count of lookups that had to compute a fresh value.
    misses: int = 0
    # Entry lifetime in seconds (default: 24 hours).
    ttl_seconds: float = 86400.0
    # Insertion-ordered key -> entry mapping; hits re-insert the key at
    # the end, so iteration order doubles as LRU recency order.
    _cache: dict[K, _CacheEntry[V]] = field(default_factory=dict)
    # Per-key events used to deduplicate concurrent computations.
    _in_flight: dict[K, asyncio.Event] = field(default_factory=dict)
    # Guards _cache, _in_flight, and the hit/miss counters.
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    def __post_init__(self) -> None:
        """Register this cache instance for bulk clearing."""
        _all_caches.append(self)

async def get_or_compute(
self,
key: K,
compute: asyncio.coroutines,
) -> V:
Comment on lines 55 to 57
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compute parameter type annotation is incorrect: asyncio.coroutines is not a useful callable/awaitable type here and won’t type-check. Consider typing this as an async callable (e.g., Callable[[], Awaitable[V]]) so callers and static type checkers have the right contract.

Copilot uses AI. Check for mistakes.
"""Return a cached value or compute it.

Args:
key: The cache key (must be hashable).
compute: An async callable that produces the value on cache miss.

Returns:
Cached or freshly-computed value.
"""
async with self._lock:
entry = self._cache.get(key)
if entry and (time.monotonic() - entry.created) < self.ttl_seconds:
self.hits += 1
self._cache[key] = self._cache.pop(key)
logger.info("%s hit key=%s", self.name, key)
return entry.value
Comment on lines 70 to 73
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.info("%s hit key=%s", ...) logs the full cache key, which for pipeline_cache/resolve_cache includes user-provided query/mention text. Since logs go to CloudWatch at INFO level, this can leak potentially sensitive user input and will also be very noisy in a hot path. Consider removing this log, downgrading to DEBUG, and/or logging only non-sensitive metadata (e.g., cache name + a hash of the key).

Copilot uses AI. Check for mistakes.

event = self._in_flight.get(key)
if event is not None:
pass # fall through to await below
else:
event = asyncio.Event()
self._in_flight[key] = event
event = None # signal that we are the owner

if event is not None:
await event.wait()
async with self._lock:
entry = self._cache.get(key)
if entry:
self.hits += 1
self._cache[key] = self._cache.pop(key)
return entry.value

self.misses += 1
logger.info("%s miss key=%s", self.name, key)
try:
Comment on lines 92 to 95
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.info("%s miss key=%s", ...) has the same issue as the hit log: it records the raw cache key (including user query/mention text) at INFO level, which is both sensitive and high-volume. Consider removing it, switching to DEBUG, and/or logging only anonymized key info.

Copilot uses AI. Check for mistakes.
value = await compute()
finally:
async with self._lock:
ev = self._in_flight.pop(key, None)
if ev is not None:
ev.set()

async with self._lock:
if len(self._cache) >= self.max_size:
oldest = next(iter(self._cache))
del self._cache[oldest]
self._cache[key] = _CacheEntry(
created=time.monotonic(), value=value
)

Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In-flight deduplication is currently racy: the in-flight Event is set in the finally block before the computed value is inserted into _cache. Waiters can wake up, not find the entry yet, and then proceed to compute again (breaking the “run only once” guarantee and potentially causing stampedes). Consider storing the result (or exception) and updating _cache while still holding the lock, and only then setting the event (or use an asyncio.Future/Task per key to propagate result/exception).

Suggested change
try:
value = await compute()
finally:
async with self._lock:
ev = self._in_flight.pop(key, None)
if ev is not None:
ev.set()
async with self._lock:
if len(self._cache) >= self.max_size:
oldest = next(iter(self._cache))
del self._cache[oldest]
self._cache[key] = _CacheEntry(
created=time.monotonic(), value=value
)
success = False
try:
value = await compute()
success = True
finally:
async with self._lock:
if success:
if len(self._cache) >= self.max_size:
oldest = next(iter(self._cache))
del self._cache[oldest]
self._cache[key] = _CacheEntry(
created=time.monotonic(), value=value
)
ev = self._in_flight.pop(key, None)
if ev is not None:
ev.set()

Copilot uses AI. Check for mistakes.
return value

async def clear(self) -> int:
"""Remove all cached entries and reset counters.

Returns:
Number of entries that were cleared.
"""
async with self._lock:
n = len(self._cache)
self._cache.clear()
self.hits = 0
self.misses = 0
return n

@property
def stats(self) -> dict:
"""Return cache statistics."""
total = self.hits + self.misses
return {
"hit_rate": round(self.hits / total, 3) if total else 0,
"hits": self.hits,
"misses": self.misses,
"size": len(self._cache),
}
Comment on lines +126 to +135
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stats property reads self.hits, self.misses, and len(self._cache) without acquiring self._lock. This could lead to race conditions where the stats are inconsistent or stale.

For example, if one coroutine is updating the cache while another calls stats, the hit rate calculation could be based on mismatched values of hits and misses, or the size could be inconsistent with the hit/miss counts.

Consider either:

  1. Acquiring the lock in the stats property: async with self._lock: return {...}
  2. Documenting that stats are eventually consistent and may be slightly stale
  3. Using atomic operations or a separate lock for stats if performance is a concern

Note that making this an async property would require changing all call sites to await it.

Copilot uses AI. Check for mistakes.


async def clear_all() -> dict[str, int]:
    """Clear all registered cache instances.

    Returns:
        Dict mapping cache name to number of entries cleared.
    """
    return {cache.name: await cache.clear() for cache in _all_caches}
Loading