
Commit 88e3794

Cache expensive calls to LLMs (#5)
1 parent 6ddd055 commit 88e3794

6 files changed: +144 additions, −23 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies = [
     "pydantic-ai",
     "shapely",
     "aiohttp",
+    "cachetools>=5.0.0",
 ]

 [tool.setuptools]

stac_search/agents/collections_search.py

Lines changed: 17 additions & 3 deletions
@@ -12,6 +12,7 @@

 from pydantic_ai import Agent
 from stac_search.catalog_manager import CatalogManager
+from stac_search.cache import async_cached, embedding_cache, agent_cache


 logger = logging.getLogger(__name__)

@@ -59,6 +60,19 @@ class RankedCollections:
     )


+@async_cached(embedding_cache)
+async def _generate_query_embedding(catalog_manager, query: str):
+    """Generate cached embedding for query string"""
+    return await asyncio.to_thread(catalog_manager.model.encode, [query])
+
+
+@async_cached(agent_cache)
+async def _run_rerank_agent(user_prompt: str) -> RankedCollections:
+    """Run the rerank agent with caching"""
+    result = await rerank_agent.run(user_prompt)
+    return result.data
+
+
 async def collection_search(
     query: str,
     top_k: int = 5,

@@ -98,7 +112,7 @@ async def collection_search(
     logger.info(f"Model loading time: {load_model_time - start_time:.4f} seconds")

     # Generate query embedding
-    query_embedding = await asyncio.to_thread(catalog_manager.model.encode, [query])
+    query_embedding = await _generate_query_embedding(catalog_manager, query)

     # Search vector database
     results = await asyncio.to_thread(

@@ -122,9 +136,9 @@
         {collections_text}
     """

-    agent_result = await rerank_agent.run(user_prompt)
+    agent_result = await _run_rerank_agent(user_prompt)

-    return agent_result.data.results
+    return agent_result.results


 async def main():
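Taken together, the two helpers above let a repeated collection query skip both the SentenceTransformer encode call and the rerank LLM call for the length of the TTL. A minimal sketch of how the cached embedding helper could be checked, assuming a local CatalogManager with the default DATA_PATH; the query string and the timing print are illustrative only, and note that the manager instance is part of the cache key, so the hit only occurs when the same instance is reused:

import asyncio
import time

from stac_search.catalog_manager import CatalogManager
from stac_search.agents.collections_search import _generate_query_embedding

async def main():
    manager = CatalogManager()  # assumes the default DATA_PATH is available
    t0 = time.perf_counter()
    await _generate_query_embedding(manager, "global land cover")  # runs model.encode
    t1 = time.perf_counter()
    await _generate_query_embedding(manager, "global land cover")  # served from embedding_cache
    t2 = time.perf_counter()
    print(f"first call: {t1 - t0:.3f}s, cached call: {t2 - t1:.6f}s")

asyncio.run(main())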

stac_search/agents/items_search.py

Lines changed: 48 additions & 18 deletions
@@ -2,12 +2,11 @@
 import json
 import logging
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from pprint import pformat
 import time
 import asyncio
 from typing import List, Dict, Any, Union
-
 import aiohttp
 from pydantic_ai import Agent, RunContext
 from pystac_client import Client

@@ -17,6 +16,7 @@
     collection_search,
     CollectionWithExplanation,
 )
+from stac_search.cache import async_cached, agent_cache, geocoding_cache


 GEODINI_API = os.getenv("GEODINI_API", "https://geodini.k8s.labs.ds.io")

@@ -70,6 +70,12 @@ def search_items_agent_system_prompt():
     return f"The current date is {date.today()}"


+@async_cached(agent_cache)
+async def _run_search_items_agent(query: str, deps: dict) -> ItemSearchParams:
+    result = await search_items_agent.run(query, deps=Context(**deps))
+    return result.data
+
+
 @dataclass
 class CollectionQuery:
     query: str

@@ -108,15 +114,22 @@ class CollectionSearchResult:
     collections: List[CollectionWithExplanation]


+@async_cached(agent_cache)
+async def _run_collection_query_framing_agent(query: str) -> CollectionQuery:
+    result = await collection_query_framing_agent.run(query)
+    return result.data
+
+
+@async_cached(agent_cache)
 async def search_collections(
     query: str, catalog_url: str = None
 ) -> CollectionSearchResult | None:
     logger.info("Searching for relevant collections ...")
-    collection_query = await collection_query_framing_agent.run(query)
-    logger.info(f"Framed collection query: {collection_query.data.query}")
-    if collection_query.data.is_specific:
+    collection_query = await _run_collection_query_framing_agent(query)
+    logger.info(f"Framed collection query: {collection_query.query}")
+    if collection_query.is_specific:
         collections = await collection_search(
-            collection_query.data.query, catalog_url=catalog_url
+            collection_query.query, catalog_url=catalog_url
         )
         return CollectionSearchResult(collections=collections)
     else:

@@ -143,10 +156,15 @@ class GeocodingResult:
     )


+@async_cached(geocoding_cache)
+async def _run_geocoding_agent(query: str) -> GeocodingResult:
+    result = await geocoding_agent.run(query)
+    return result.data
+
+
 @search_items_agent.tool
 async def set_spatial_extent(ctx: RunContext[Context]) -> GeocodingResult:
-    result = await geocoding_agent.run(ctx.deps.query)
-    return result.data
+    return await _run_geocoding_agent(ctx.deps.query)


 @dataclass

@@ -170,10 +188,15 @@ def temporal_range_agent_system_prompt():
     return f"The current date is {date.today()}"


+@async_cached(agent_cache)
+async def _run_temporal_range_agent(query: str) -> TemporalRangeResult:
+    result = await temporal_range_agent.run(query)
+    return result.data
+
+
 @search_items_agent.tool
 async def set_temporal_range(ctx: RunContext[Context]) -> TemporalRangeResult:
-    result = await temporal_range_agent.run(ctx.deps.query)
-    return result.data
+    return await _run_temporal_range_agent(ctx.deps.query)


 class PropertyRef(BaseModel):

@@ -255,11 +278,18 @@ class FilterExpr(BaseModel):
     )


+@async_cached(agent_cache)
+async def _run_cql2_filter_agent(query: str) -> FilterExpr | None:
+    result = await cql2_filter_agent.run(query)
+    return result.data
+
+
 @search_items_agent.tool
 async def construct_cql2_filter(ctx: RunContext[Context]) -> FilterExpr | None:
-    return await cql2_filter_agent.run(ctx.deps.query)
+    return await _run_cql2_filter_agent(ctx.deps.query)


+@async_cached(geocoding_cache)
 async def get_polygon_from_geodini(location: str):
     geodini_api = f"{GEODINI_API}/search_complex"
     async with aiohttp.ClientSession() as session:

@@ -281,8 +311,8 @@ class ItemSearchResult:
 async def item_search(ctx: Context) -> ItemSearchResult:
     start_time = time.time()
     # formulate the query to be used for the search
-    results = await search_items_agent.run(
-        f"Find items for the query: {ctx.query}", deps=ctx
+    results = await _run_search_items_agent(
+        query=f"Find items for the query: {ctx.query}", deps=asdict(ctx)
     )
     query_formulation_time = time.time()
     logger.info(

@@ -330,8 +360,8 @@ async def item_search(ctx: Context) -> ItemSearchResult:
     params = {
         "max_items": 20,
         "collections": collections_to_search,
-        "datetime": results.data.datetime,
-        "filter": results.data.filter,
+        "datetime": results.datetime,
+        "filter": results.filter,
     }

     logger.info(f"Searching with params: {params}")

@@ -340,12 +370,12 @@
         f"Params formulation time: {params_formulation_time - query_formulation_time} seconds"
     )

-    polygon = await get_polygon_from_geodini(results.data.location)
+    polygon = await get_polygon_from_geodini(results.location)
     if polygon:
-        logger.info(f"Found polygon for {results.data.location}")
+        logger.info(f"Found polygon for {results.location}")
         params["intersects"] = polygon
     else:
-        explanation += f"\n\n No polygon found for {results.data.location}. "
+        explanation += f"\n\n No polygon found for {results.location}. "
         return ItemSearchResult(
             items=None, search_params=params, aoi=None, explanation=explanation
         )
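The call-site change in item_search, passing deps=asdict(ctx) instead of the Context instance itself, is what makes the agent result cacheable across requests: _freeze turns a plain dict into a hashable frozenset, whereas a dataclass instance falls through to the default branch and is keyed by object identity, so two equivalent requests would never share an entry. A small sketch under that assumption, using a simplified stand-in Context (not the full definition from items_search.py) and an illustrative prompt:

from dataclasses import dataclass, asdict

from cachetools.keys import hashkey

from stac_search.cache import _freeze

@dataclass
class Context:  # simplified stand-in, not the full Context from items_search.py
    query: str
    catalog_url: str | None = None

a = Context(query="flooding in Kerala, August 2018")
b = Context(query="flooding in Kerala, August 2018")

key_a = hashkey("_run_search_items_agent", "prompt", _freeze(asdict(a)))
key_b = hashkey("_run_search_items_agent", "prompt", _freeze(asdict(b)))
assert key_a == key_b  # equal field values -> identical cache key, so the second request is a hit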

stac_search/api.py

Lines changed: 6 additions & 1 deletion
@@ -2,15 +2,19 @@
 FastAPI server for STAC Natural Query
 """

+import logging
+from typing import Optional
+
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from typing import Optional
 import uvicorn

 from stac_search.agents.collections_search import collection_search
 from stac_search.agents.items_search import item_search, Context as ItemSearchContext

+logger = logging.getLogger(__name__)
+
 # Initialize FastAPI app
 app = FastAPI(
     title="STAC Natural Query API",

@@ -65,6 +69,7 @@ async def search_items(request: STACItemsRequest):
         results = await item_search(ctx)
         return {"results": results}
     except Exception as e:
+        logger.exception(e)
         raise HTTPException(status_code=500, detail=str(e))

stac_search/cache.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+"""
+Caching module for STAC Natural Query - handles various caching strategies
+"""
+
+import asyncio
+import logging
+from functools import wraps
+
+from cachetools.keys import hashkey
+from cachetools import TTLCache
+
+logger = logging.getLogger(__name__)
+
+
+geocoding_cache = TTLCache(maxsize=100, ttl=86400)  # 24 hours - locations don't change
+embedding_cache = TTLCache(maxsize=100, ttl=86400)  # 24 hours - embeddings are stable
+agent_cache = TTLCache(maxsize=100, ttl=3600)  # 1 hour - agent results cache
+
+
+def _freeze(obj):
+    if isinstance(obj, dict):
+        # sort items to make order deterministic
+        return frozenset((k, _freeze(v)) for k, v in sorted(obj.items()))
+    if isinstance(obj, (list, tuple)):
+        return tuple(_freeze(v) for v in obj)
+    if isinstance(obj, set):
+        return frozenset(_freeze(v) for v in obj)
+    return obj  # assume primitive (int, str, etc.)
+
+
+def async_cached(cache):
+    lock = asyncio.Lock()
+
+    def decorator(fn):
+        @wraps(fn)
+        async def wrapper(*args, **kwargs):
+            # freeze each arg/kwarg
+            fargs = tuple(_freeze(a) for a in args)
+            fkwargs = {k: _freeze(v) for k, v in kwargs.items()}
+            key = hashkey(f"{fn.__name__}", *fargs, **fkwargs)
+            if key in cache:
+                return cache[key]
+            async with lock:
+                if key in cache:
+                    return cache[key]
+                result = await fn(*args, **kwargs)
+                cache[key] = result
+                return result
+
+        return wrapper
+
+    return decorator
+
+
+def clear_all_caches():
+    """
+    Clear all caches
+    """
+    logger.info("Clearing all caches")
+    geocoding_cache.clear()
+    embedding_cache.clear()
+    agent_cache.clear()
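The decorator uses double-checked locking: a lock-free read first, then the lock is taken only on a miss, so concurrent identical calls do not all hit the LLM. A minimal usage sketch assuming only the new module; slow_lookup, its arguments, and the 60-second TTL are hypothetical:

import asyncio

from cachetools import TTLCache

from stac_search.cache import async_cached

demo_cache = TTLCache(maxsize=32, ttl=60)  # hypothetical: small cache, 60-second TTL

@async_cached(demo_cache)
async def slow_lookup(query: str, options: dict) -> str:
    # stand-in for an expensive LLM or geocoding call
    await asyncio.sleep(1)
    return f"result for {query}"

async def main():
    # dict arguments are frozen by _freeze, so both calls hash to the same key
    first = await slow_lookup("sentinel-2 over Nairobi", {"top_k": 5})
    second = await slow_lookup("sentinel-2 over Nairobi", {"top_k": 5})  # cache hit, no sleep
    assert first == second

asyncio.run(main())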

stac_search/catalog_manager.py

Lines changed: 10 additions & 1 deletion
@@ -11,11 +11,15 @@
 from pystac_client import Client
 from sentence_transformers import SentenceTransformer

+from stac_search.cache import async_cached, embedding_cache
+
+
 logger = logging.getLogger(__name__)

 # Constants
 MODEL_NAME = "all-MiniLM-L6-v2"
 DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb")
+MODEL = SentenceTransformer(MODEL_NAME)


 class CatalogManager:

@@ -24,11 +28,15 @@ class CatalogManager:
     def __init__(self, data_path: str = DATA_PATH, model_name: str = MODEL_NAME):
         self.data_path = data_path
         self.model_name = model_name
-        self.model = SentenceTransformer(model_name)
         self.client = chromadb.PersistentClient(path=data_path)

+    @property
+    def model(self):
+        return MODEL
+
     def _get_catalog_name(self, catalog_url: str) -> str:
         """Generate a unique catalog name from URL"""
+        logger.info(f"Generating catalog name for {catalog_url}")
         # Create a hash of the URL for consistent naming
         url_hash = hashlib.md5(catalog_url.encode()).hexdigest()[:8]
         # Clean URL for readability

@@ -81,6 +89,7 @@ def _fetch():
                 logger.error(f"Error fetching collections: {e}")
                 return []

+    @async_cached(embedding_cache)
     async def generate_embeddings(self, collections: list) -> list:
         """Generate embeddings for each collection (title + description)"""
         texts = []
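Moving the SentenceTransformer to a module-level constant means the model weights load once per process, and every CatalogManager instance shares them through the read-only model property. A trimmed sketch of that pattern (stand-in class, not the full implementation):

from sentence_transformers import SentenceTransformer

MODEL_NAME = "all-MiniLM-L6-v2"
MODEL = SentenceTransformer(MODEL_NAME)  # weights load once, at import time

class CatalogManager:  # trimmed stand-in for the real class
    @property
    def model(self) -> SentenceTransformer:
        return MODEL  # every instance returns the same shared model

# both instances see the same object, so the load cost is paid once per process
assert CatalogManager().model is CatalogManager().model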
