"""Generic async LRU cache with TTL and in-flight deduplication."""

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Generic, TypeVar

logger = logging.getLogger(__name__)

K = TypeVar("K")
V = TypeVar("V")

# Registry of all cache instances for bulk clear.  The annotation is quoted
# because ``LRUCache`` is only defined further down.
_all_caches: "list[LRUCache]" = []  # type: ignore[type-arg]


@dataclass
class _CacheEntry(Generic[V]):
    """A cached value with creation timestamp."""

    created: float  # ``time.monotonic()`` reading taken when the value was stored
    value: V


@dataclass
class LRUCache(Generic[K, V]):
    """Async LRU cache with TTL and in-flight deduplication.

    - Entries expire after ``ttl_seconds``; an expired entry is dropped the
      next time its key is looked up.
    - When ``max_size`` is reached the least recently used entry is evicted.
      ``max_size <= 0`` disables storage (every call recomputes).
    - Concurrent calls for the same key share a single computation.  If that
      computation fails, exactly one waiter retries while the others keep
      waiting, so a failure never fans out into parallel recomputation.

    All instances are registered for bulk clearing via ``clear_all()``.
    """

    name: str
    hits: int = 0
    max_size: int = 10_000
    misses: int = 0
    ttl_seconds: float = 86400.0
    # Insertion order doubles as recency order: first key == least recent.
    _cache: dict[K, _CacheEntry[V]] = field(default_factory=dict)
    # One event per key whose value is currently being computed.
    _in_flight: dict[K, asyncio.Event] = field(default_factory=dict)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    def __post_init__(self) -> None:
        """Register this cache instance for bulk clearing."""
        _all_caches.append(self)

    def _is_fresh(self, entry: _CacheEntry[V]) -> bool:
        """Return True while *entry* is younger than ``ttl_seconds``."""
        return (time.monotonic() - entry.created) < self.ttl_seconds

    async def get_or_compute(
        self,
        key: K,
        compute: Callable[[], Awaitable[V]],
    ) -> V:
        """Return a cached value or compute it.

        Args:
            key: The cache key (must be hashable).
            compute: An async callable that produces the value on cache miss.

        Returns:
            Cached or freshly-computed value.

        Raises:
            Whatever ``compute`` raises; failed computations are not cached.
        """
        while True:
            async with self._lock:
                entry = self._cache.get(key)
                if entry is not None:
                    if self._is_fresh(entry):
                        self.hits += 1
                        # Re-insert so the key becomes the most recent one.
                        self._cache[key] = self._cache.pop(key)
                        logger.debug("%s hit key=%s", self.name, key)
                        return entry.value
                    # Expired: remove it now so it neither occupies a slot
                    # nor keeps its old recency position when refreshed.
                    del self._cache[key]

                pending = self._in_flight.get(key)
                if pending is None:
                    # Nobody is computing this key: we become the owner.
                    own_event = asyncio.Event()
                    self._in_flight[key] = own_event
                    break

            # Someone else owns the computation.  Wait, then loop to re-check
            # the cache; if the owner failed, the first waiter to get the
            # lock becomes the new owner and the rest wait on its event.
            await pending.wait()

        self.misses += 1
        logger.debug("%s miss key=%s", self.name, key)
        succeeded = False
        try:
            value = await compute()
            succeeded = True
        finally:
            # Runs on success, failure and cancellation so waiters are
            # always released.
            async with self._lock:
                if succeeded and self.max_size > 0:
                    self._cache.pop(key, None)
                    while len(self._cache) >= self.max_size:
                        del self._cache[next(iter(self._cache))]
                    self._cache[key] = _CacheEntry(
                        created=time.monotonic(), value=value
                    )
                # Only remove the registration if it is still ours.
                if self._in_flight.get(key) is own_event:
                    del self._in_flight[key]
                own_event.set()

        return value

    async def clear(self) -> int:
        """Remove all cached entries and reset counters.

        In-flight computations are left alone; their owners still release
        their waiters when they finish.

        Returns:
            Number of entries that were cleared.
        """
        async with self._lock:
            n = len(self._cache)
            self._cache.clear()
            self.hits = 0
            self.misses = 0
            return n

    @property
    def stats(self) -> dict:
        """Return cache statistics (hit rate, hits, misses, current size)."""
        total = self.hits + self.misses
        return {
            "hit_rate": round(self.hits / total, 3) if total else 0,
            "hits": self.hits,
            "misses": self.misses,
            "size": len(self._cache),
        }


async def clear_all() -> dict[str, int]:
    """Clear all registered cache instances.

    Returns:
        Dict mapping cache name to number of entries cleared.
    """
    results: dict[str, int] = {}
    # Iterate over a snapshot: awaiting yields control, and another task
    # could register a new cache in the meantime.
    for cache in list(_all_caches):
        results[cache.name] = await cache.clear()
    return results
""" def evaluate( @@ -30,7 +31,6 @@ def evaluate( expected = ctx.expected_output actual = ctx.output if expected is None or not expected.values: - # If we expect empty, score 1.0 if actual is also empty return { "resolve_score": 1.0 if not actual.values else 0.0 } @@ -39,8 +39,13 @@ def evaluate( act_set = {v.lower() for v in actual.values} if not act_set: return {"resolve_score": 0.0} - hits = exp_set & act_set - return {"resolve_score": round(len(hits) / len(exp_set), 3)} + hits = len(exp_set & act_set) + precision = hits / len(act_set) + recall = hits / len(exp_set) + if precision + recall == 0: + return {"resolve_score": 0.0} + f1 = 2 * precision * recall / (precision + recall) + return {"resolve_score": round(f1, 3)} def _mention(text: str, facet: Facet) -> RawMention: @@ -48,6 +53,33 @@ def _mention(text: str, facet: Facet) -> RawMention: return RawMention(facet=facet, text=text, values=[]) +# --------------------------------------------------------------------------- +# Dynamic consent code expectations +# --------------------------------------------------------------------------- + +def _consent_expected(**kwargs: object) -> ResolveResult: + """Compute expected consent code values deterministically. + + Loads the index once, gets all consent codes, and calls + ``compute_eligible_codes`` with the given kwargs. This keeps + expectations in sync with the actual catalog data. + + Args: + **kwargs: Forwarded to ``compute_eligible_codes`` (after resolving + any ``disease`` name to an abbreviation). + + Returns: + A :class:`ResolveResult` with sorted eligible codes. 
+ """ + index = get_index() + all_codes = [m.value for m in index.list_facet_values("consentCode")] + # Resolve disease name if provided + if "disease" in kwargs: + kwargs["disease"] = resolve_disease_name(str(kwargs["disease"])) + eligible = compute_eligible_codes(all_codes, **kwargs) # type: ignore[arg-type] + return ResolveResult(values=sorted(eligible)) + + dataset = Dataset[RawMention, ResolveResult, ResolveResult]( evaluators=[ResolveEvaluator()], cases=[ @@ -62,29 +94,23 @@ def _mention(text: str, facet: Facet) -> RawMention: inputs=_mention("systolic blood pressure", Facet.MEASUREMENT), expected_output=ResolveResult(values=["Systolic Blood Pressure"]), ), - Case( - name="direct-diabetes-focus", - inputs=_mention("diabetes", Facet.FOCUS), - expected_output=ResolveResult(values=["Diabetes Mellitus"]), - ), - Case( - name="direct-consent-gru", - inputs=_mention("GRU", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["GRU"]), - ), + # (diabetes/focus is tested as focus-diabetes below) + # (GRU direct match is tested as consent-gru-direct below) # --- Lay term rewrites --- Case( name="lay-blood-sugar", inputs=_mention("blood sugar", Facet.MEASUREMENT), # Agent should search "blood sugar", find low-count results, - # then rewrite to "glucose" and find Fasting Glucose (44 studies). - expected_output=ResolveResult(values=["Fasting Glucose"]), + # then rewrite to "glucose" and find glucose-related concepts. 
+ expected_output=ResolveResult( + values=["Blood Glucose", "Fasting Glucose", "Glucose", "Serum Glucose"] + ), ), Case( name="lay-blood-pressure", inputs=_mention("blood pressure", Facet.MEASUREMENT), expected_output=ResolveResult( - values=["Systolic Blood Pressure"] + values=["Diastolic Blood Pressure", "Systolic Blood Pressure"] ), ), # --- Category expansion --- @@ -92,24 +118,53 @@ def _mention(text: str, facet: Facet) -> RawMention: name="category-sleep", inputs=_mention("sleep", Facet.MEASUREMENT), # "sleep" is broad — agent should return multiple sleep concepts. - # Accept any result that includes Sleep Duration. - expected_output=ResolveResult(values=["Sleep Duration"]), + expected_output=ResolveResult( + values=[ + "Daytime Sleepiness", + "Epworth Sleepiness Scale", + "Excessive Daytime Sleepiness", + "Obstructive Sleep Apnea History", + "Oxygen Therapy Use During Sleep", + "Sleep Apnea History", + "Sleep Disorder History", + "Sleep Disturbance", + "Sleep Duration", + "Sleep Efficiency", + "Sleep Latency", + "Sleep Maintenance Insomnia", + "Sleep Medication Use", + "Sleep Onset Difficulty", + "Sleep Onset Insomnia", + "Sleep Onset Latency", + "Sleep Problems", + "Sleep Quality", + "Total Sleep Time", + "Wake After Sleep Onset", + ] + ), ), Case( name="category-cholesterol", inputs=_mention("cholesterol", Facet.MEASUREMENT), - # "cholesterol" is ambiguous — Total, HDL, LDL, Dietary, etc. + # "cholesterol" is ambiguous — Total, HDL, LDL, Triglycerides. # Agent should return broad match and disambiguate via message. - expected_output=ResolveResult(values=["Total Cholesterol"]), + expected_output=ResolveResult( + values=["HDL Cholesterol", "LDL Cholesterol", "Total Cholesterol", "Triglycerides"] + ), ), Case( name="disambig-glucose", inputs=_mention("glucose", Facet.MEASUREMENT), - # "glucose" spans 5 categories: Endocrine (Fasting Glucose, 44), - # Lab Tests (Glucose, 38), Dietary (Glucose Intake, 5), etc. 
- # Agent should pick Fasting Glucose (highest count) and - # disambiguate — importantly NOT an example in the prompt. - expected_output=ResolveResult(values=["Fasting Glucose"]), + # "glucose" spans multiple categories — agent should return + # the main glucose-related measurement concepts. + expected_output=ResolveResult( + values=[ + "2-Hour Plasma Glucose", + "Fasting Blood Glucose", + "Fasting Glucose", + "Postprandial Blood Glucose", + ] + ), ), # --- Medical synonym --- Case( @@ -120,21 +175,26 @@ def _mention(text: str, facet: Facet) -> RawMention: Case( name="synonym-smoking", inputs=_mention("smoking", Facet.MEASUREMENT), - expected_output=ResolveResult(values=["Smoking Status"]), + expected_output=ResolveResult( + values=["Current Smoking Status", "Smoking History", "Smoking Status"] + ), ), # --- Harder rewrites --- Case( name="rewrite-vitamin-k", inputs=_mention("vitamin K", Facet.MEASUREMENT), - expected_output=ResolveResult(values=["Vitamin K Intake"]), + expected_output=ResolveResult( + values=["Vitamin K Intake", "Vitamin K Supplementation"] + ), ), Case( name="rewrite-heart-disease", inputs=_mention("heart disease", Facet.FOCUS), # With category drill-down, agent sees full list and picks - # the broader "Cardiovascular Diseases" (81 studies) over - # "Heart Diseases" (4 studies). Both are valid. - expected_output=ResolveResult(values=["Cardiovascular Diseases"]), + # both broad and specific heart disease terms. 
+ expected_output=ResolveResult( + values=["Cardiovascular Diseases", "Heart Diseases"] + ), ), # --- Focus/disease via category drill-down --- Case( @@ -150,7 +210,14 @@ def _mention(text: str, facet: Facet) -> RawMention: Case( name="focus-lung-cancer", inputs=_mention("lung cancer", Facet.FOCUS), - expected_output=ResolveResult(values=["Lung Neoplasms"]), + expected_output=ResolveResult( + values=[ + "Adenocarcinoma of Lung", + "Carcinoma, Non-Small-Cell Lung", + "Lung Neoplasms", + "Small Cell Lung Carcinoma", + ] + ), ), Case( name="focus-diabetes", @@ -209,127 +276,108 @@ def _mention(text: str, facet: Facet) -> RawMention: inputs=_mention("schizophrenia", Facet.FOCUS), expected_output=ResolveResult(values=["Schizophrenia"]), ), - # --- Consent code semantic resolution --- + # --- Consent code semantic resolution (dynamic expectations) --- Case( name="consent-gru-direct", inputs=_mention("GRU", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["GRU"]), + expected_output=_consent_expected(explicit_code="GRU"), ), Case( name="consent-general-research", inputs=_mention("general research use", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["GRU"]), + expected_output=_consent_expected(explicit_code="GRU"), ), Case( name="consent-hmb-direct", inputs=_mention("HMB", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["HMB"]), + expected_output=_consent_expected(explicit_code="HMB"), ), Case( name="consent-health-medical", inputs=_mention("health medical biomedical", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["HMB"]), + expected_output=_consent_expected(purpose="health"), ), Case( name="consent-disease-specific-cvd", inputs=_mention("cardiovascular disease specific", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["DS-CVD"]), + expected_output=_consent_expected(explicit_code="DS-CVD"), ), Case( name="consent-breast-cancer", inputs=_mention("breast cancer research", Facet.CONSENT_CODE), - 
expected_output=ResolveResult(values=["DS-BRCA"]), + expected_output=_consent_expected(purpose="disease", disease="BRCA"), ), Case( name="consent-not-for-profit", inputs=_mention("general research, not for profit", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["GRU-NPU"]), + expected_output=_consent_expected(purpose="general"), ), Case( name="consent-hmb-irb", inputs=_mention("HMB-IRB", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["HMB-IRB"]), + expected_output=_consent_expected(explicit_code="HMB-IRB"), ), Case( name="consent-diabetes", inputs=_mention("diabetes research", Facet.CONSENT_CODE), - # Eligibility: should return DS-DIAB-* codes (via compute_consent_eligibility) - expected_output=ResolveResult(values=["DS-DIAB-NPU"]), + expected_output=_consent_expected(purpose="disease", disease="DIAB"), ), # --- Consent eligibility resolution --- Case( name="consent-for-profit-cancer", inputs=_mention("for-profit cancer", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["DS-CA"]), - ), - Case( - name="consent-explicit-gru", - inputs=_mention("GRU", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["GRU", "GRU-IRB"]), - ), - Case( - name="consent-explicit-hmb", - inputs=_mention("HMB", Facet.CONSENT_CODE), - expected_output=ResolveResult(values=["HMB", "HMB-IRB"]), + expected_output=_consent_expected(purpose="disease", disease="CA", is_nonprofit=False), ), Case( name="consent-sub-disease", inputs=_mention("type 1 diabetes research consent", Facet.CONSENT_CODE), - # DS-T1D-IRB exists in the index - expected_output=ResolveResult(values=["DS-T1D-IRB"]), + expected_output=_consent_expected(purpose="disease", disease="T1D"), ), Case( name="consent-consented-diabetes", inputs=_mention("diabetes", Facet.CONSENT_CODE), - # Should include GRU (always eligible) and HMB (health/disease) - # plus DS-DIAB family — recall scoring checks all are present - expected_output=ResolveResult( - values=["GRU", "HMB", "DS-DIAB-NPU"] - ), + 
expected_output=_consent_expected(purpose="disease", disease="DIAB"), ), Case( name="consent-consented-alzheimers", inputs=_mention("Alzheimer's", Facet.CONSENT_CODE), - # GRU always eligible, HMB for disease research - expected_output=ResolveResult(values=["GRU", "HMB"]), + expected_output=_consent_expected(purpose="health"), ), Case( name="consent-disease-only-diabetes", inputs=_mention("diabetes only", Facet.CONSENT_CODE), - # "only" → disease_only=True, should return DS-DIAB* but NOT GRU/HMB - expected_output=ResolveResult(values=["DS-DIAB-NPU"]), + expected_output=_consent_expected(purpose="disease", disease="DIAB", disease_only=True), ), Case( name="consent-disease-only-cancer", inputs=_mention("specifically cancer", Facet.CONSENT_CODE), - # "specifically" → disease_only=True - expected_output=ResolveResult(values=["DS-CA"]), + expected_output=_consent_expected(purpose="disease", disease="CA", disease_only=True), ), # --- GRU vs HMB disambiguation --- Case( name="consent-social-science", inputs=_mention("social science behavioral genetics research", Facet.CONSENT_CODE), - # NOT health/medical → general purpose → GRU only + # NOT health/medical -> general purpose -> GRU only # HMB is restricted to health/medical/biomedical - expected_output=ResolveResult(values=["GRU"]), + expected_output=_consent_expected(purpose="general"), ), Case( name="consent-biomedical", inputs=_mention("biomedical research on aging", Facet.CONSENT_CODE), - # Explicitly biomedical → health purpose → GRU + HMB - expected_output=ResolveResult(values=["GRU", "HMB"]), + # Explicitly biomedical -> health purpose -> GRU + HMB + expected_output=_consent_expected(purpose="health"), ), Case( name="consent-for-profit-health", inputs=_mention("for-profit biomedical health research", Facet.CONSENT_CODE), - # Health purpose + for-profit → GRU + HMB minus NPU variants - expected_output=ResolveResult(values=["GRU", "HMB"]), + # Health purpose + for-profit -> GRU + HMB minus NPU variants + 
expected_output=_consent_expected(purpose="health", is_nonprofit=False), ), Case( name="consent-population-genetics", inputs=_mention("population genetics, not disease-related", Facet.CONSENT_CODE), - # Explicitly not disease/health → general → GRU only - expected_output=ResolveResult(values=["GRU"]), + # Explicitly not disease/health -> general -> GRU only + expected_output=_consent_expected(purpose="general"), ), ], ) diff --git a/backend/concept_search/extract_agent.py b/backend/concept_search/extract_agent.py index 25e5d9c..0575a2f 100644 --- a/backend/concept_search/extract_agent.py +++ b/backend/concept_search/extract_agent.py @@ -32,7 +32,10 @@ def _get_agent(model: str | None = None) -> Agent[None, ExtractResult]: model, output_type=ExtractResult, system_prompt=_load_prompt(), - model_settings=ModelSettings(anthropic_cache_instructions=True), + model_settings=ModelSettings( + anthropic_cache_instructions=True, + temperature=0.0, + ), ) return _agent diff --git a/backend/concept_search/pipeline.py b/backend/concept_search/pipeline.py index 2f51356..a42a5d5 100644 --- a/backend/concept_search/pipeline.py +++ b/backend/concept_search/pipeline.py @@ -14,8 +14,10 @@ import asyncio import logging +import os import time +from .cache import LRUCache from .extract_agent import run_extract from .index import ConceptIndex, get_index from .models import ( @@ -28,6 +30,12 @@ logger = logging.getLogger(__name__) +pipeline_cache: LRUCache[str, QueryModel] = LRUCache( + name="pipeline_cache", + max_size=int(os.environ.get("PIPELINE_CACHE_MAX_SIZE", "10000")), + ttl_seconds=float(os.environ.get("PIPELINE_CACHE_TTL_SECONDS", "86400")), +) + # --------------------------------------------------------------------------- # Resolve step (parallel per mention) @@ -149,12 +157,12 @@ def _merge( # Main pipeline # --------------------------------------------------------------------------- -async def run_pipeline( +async def _run_pipeline_uncached( query: str, index: ConceptIndex | 
async def run_pipeline(
    query: str,
    index: ConceptIndex | None = None,
    model: str | None = None,
) -> QueryModel:
    """Run the full 3-agent pipeline on a natural-language query.

    Results are cached by normalized query string (stripped, lower-cased)
    and, when given, by the ``model`` override. A cache hit skips all three
    agents (extract, resolve, structure).

    Args:
        query: The user's natural-language search query.
        index: ConceptIndex to use. If None, uses the shared singleton.
        model: Override the model for all agents.

    Returns:
        Structured QueryModel with resolved mentions and boolean logic.
    """
    normalized = query.strip().lower()
    # A result produced by one model must not be served to a caller that
    # asked for another.  Calls without an override keep the plain
    # normalized query as their key; NUL cannot occur in a sensible query,
    # so the prefixed form cannot collide with it.
    key = normalized if model is None else f"{model}\x00{normalized}"
    # NOTE(review): ``index`` is not part of the key, so callers passing a
    # non-default index share entries with the singleton -- confirm that
    # this is acceptable.
    return await pipeline_cache.get_or_compute(
        key, lambda: _run_pipeline_uncached(query, index, model)
    )
async def run_resolve(
    mention: RawMention,
    index: ConceptIndex,
    model: str | None = None,
) -> ResolveResult:
    """Resolve a single raw mention to canonical index values.

    Results are cached by ``(facet, normalized_text)`` to avoid redundant
    LLM calls for repeated mentions.  When a ``model`` override is given it
    is folded into the facet component of the key, so results produced by
    different models are never mixed.

    Args:
        mention: The raw mention to resolve.
        index: ConceptIndex to search against.
        model: Override the model (default: Haiku).

    Returns:
        ResolveResult with canonical value(s).
    """
    facet = mention.facet.value
    # Keep the key a ``(str, str)`` pair.  Calls without an override use the
    # bare facet value; the NUL separator keeps the model-qualified form
    # from colliding with any real facet name.
    facet_key = facet if model is None else f"{facet}\x00{model}"
    key = (facet_key, mention.text.strip().lower())
    # NOTE(review): ``index`` is not part of the key -- confirm that all
    # callers resolve against the same index.
    return await resolve_cache.get_or_compute(
        key, lambda: _run_resolve_uncached(mention, index, model)
    )
@pytest.mark.asyncio()
async def test_ttl_expiration() -> None:
    """Entries should expire after TTL seconds."""
    cache = LRUCache[str, str](name="ttl-test", max_size=100, ttl_seconds=0.05)
    call_count = 0

    async def counted() -> str:
        nonlocal call_count
        call_count += 1
        return "v"

    try:
        await cache.get_or_compute("k", counted)
        await asyncio.sleep(0.1)
        await cache.get_or_compute("k", counted)
        assert call_count == 2
        assert cache.misses == 2
    finally:
        # Unregister even when an assertion fails, so the global registry
        # does not leak into later tests.
        _all_caches.remove(cache)


@pytest.mark.asyncio()
async def test_lru_eviction(cache: LRUCache[str, str]) -> None:
    """When max_size is reached, the oldest entry should be evicted."""
    for k in ["a", "b", "c"]:
        # ``k=k`` binds the current value instead of the loop variable.
        await cache.get_or_compute(k, lambda k=k: _compute(k))
    assert len(cache._cache) == 3

    await cache.get_or_compute("d", lambda: _compute("d"))
    assert len(cache._cache) == 3
    assert "a" not in cache._cache


@pytest.mark.asyncio()
async def test_lru_access_refreshes(cache: LRUCache[str, str]) -> None:
    """Accessing an entry should move it to the end (most recent)."""
    for k in ["a", "b", "c"]:
        await cache.get_or_compute(k, lambda k=k: _compute(k))
    # Access "a" to refresh it
    await cache.get_or_compute("a", lambda: _compute("a"))
    # Adding "d" should now evict "b" (the oldest untouched)
    await cache.get_or_compute("d", lambda: _compute("d"))
    assert "a" in cache._cache
    assert "b" not in cache._cache


@pytest.mark.asyncio()
async def test_in_flight_deduplication(cache: LRUCache[str, str]) -> None:
    """Concurrent computes for the same key should run only once."""
    call_count = 0

    async def slow() -> str:
        nonlocal call_count
        call_count += 1
        await asyncio.sleep(0.1)
        return "v"

    results = await asyncio.gather(
        cache.get_or_compute("k", slow),
        cache.get_or_compute("k", slow),
        cache.get_or_compute("k", slow),
    )
    assert call_count == 1
    assert all(r == "v" for r in results)


@pytest.mark.asyncio()
async def test_clear(cache: LRUCache[str, str]) -> None:
    """clear() should empty the cache and reset counters."""
    await cache.get_or_compute("a", lambda: _compute("a"))
    await cache.get_or_compute("a", lambda: _compute("a"))
    assert cache.hits == 1
    n = await cache.clear()
    assert n == 1
    assert cache.hits == 0
    assert cache.misses == 0
    assert len(cache._cache) == 0


@pytest.mark.asyncio()
async def test_stats(cache: LRUCache[str, str]) -> None:
    """stats property should report accurate metrics."""
    await cache.get_or_compute("a", lambda: _compute("a"))
    await cache.get_or_compute("a", lambda: _compute("a"))
    await cache.get_or_compute("b", lambda: _compute("b"))
    stats = cache.stats
    assert stats["hits"] == 1
    assert stats["misses"] == 2
    assert stats["size"] == 2
    assert stats["hit_rate"] == pytest.approx(0.333, abs=0.01)


@pytest.mark.asyncio()
async def test_tuple_keys() -> None:
    """Cache should work with tuple keys (used by resolve cache)."""
    cache = LRUCache[tuple[str, str], str](name="tuple-test", max_size=10)
    try:
        r1 = await cache.get_or_compute(("focus", "diabetes"), lambda: _compute("v1"))
        r2 = await cache.get_or_compute(("focus", "diabetes"), lambda: _compute("v2"))
        r3 = await cache.get_or_compute(("measurement", "diabetes"), lambda: _compute("v3"))
        assert r1 == r2 == "v1"  # cache hit
        assert r3 == "v3"  # different key
        assert cache.hits == 1
        assert cache.misses == 2
    finally:
        _all_caches.remove(cache)


@pytest.mark.asyncio()
async def test_compute_exception_not_cached(cache: LRUCache[str, str]) -> None:
    """Failed computes should not be cached, and retries should work."""

    async def failing() -> str:
        raise RuntimeError("boom")

    with pytest.raises(RuntimeError, match="boom"):
        await cache.get_or_compute("k", failing)

    assert len(cache._cache) == 0
    assert "k" not in cache._in_flight

    # Retry should succeed
    r = await cache.get_or_compute("k", lambda: _compute("ok"))
    assert r == "ok"
    assert len(cache._cache) == 1


@pytest.mark.asyncio()
async def test_in_flight_exception_propagates(cache: LRUCache[str, str]) -> None:
    """When the owner fails, the waiter retries and sees its own failure."""
    call_count = 0

    async def slow_fail() -> str:
        nonlocal call_count
        call_count += 1
        await asyncio.sleep(0.05)
        raise RuntimeError("boom")

    results = await asyncio.gather(
        cache.get_or_compute("k", slow_fail),
        cache.get_or_compute("k", slow_fail),
        return_exceptions=True,
    )

    # Owner fails; the waiter then calls compute itself and fails too.
    assert all(isinstance(r, RuntimeError) for r in results)
    assert call_count == 2
    assert len(cache._cache) == 0
    assert "k" not in cache._in_flight


@pytest.mark.asyncio()
async def test_clear_all() -> None:
    """clear_all() should clear all registered caches."""
    c1 = LRUCache[str, str](name="c1", max_size=10)
    c2 = LRUCache[str, str](name="c2", max_size=10)
    try:
        await c1.get_or_compute("a", lambda: _compute("a"))
        await c2.get_or_compute("b", lambda: _compute("b"))

        results = await clear_all()
        assert results["c1"] == 1
        assert results["c2"] == 1
        assert len(c1._cache) == 0
        assert len(c2._cache) == 0
    finally:
        _all_caches.remove(c1)
        _all_caches.remove(c2)