22import json
33import logging
44import os
5- from dataclasses import dataclass
5+ from dataclasses import dataclass , asdict
66from pprint import pformat
77import time
88import asyncio
99from typing import List , Dict , Any , Union
10-
1110import aiohttp
1211from pydantic_ai import Agent , RunContext
1312from pystac_client import Client
1716 collection_search ,
1817 CollectionWithExplanation ,
1918)
19+ from stac_search .cache import async_cached , agent_cache , geocoding_cache
2020
2121
2222GEODINI_API = os .getenv ("GEODINI_API" , "https://geodini.k8s.labs.ds.io" )
@@ -70,6 +70,12 @@ def search_items_agent_system_prompt():
7070 return f"The current date is { date .today ()} "
7171
7272
73+ @async_cached (agent_cache )
74+ async def _run_search_items_agent (query : str , deps : dict ) -> ItemSearchParams :
75+ result = await search_items_agent .run (query , deps = Context (** deps ))
76+ return result .data
77+
78+
7379@dataclass
7480class CollectionQuery :
7581 query : str
@@ -108,15 +114,22 @@ class CollectionSearchResult:
108114 collections : List [CollectionWithExplanation ]
109115
110116
117+ @async_cached (agent_cache )
118+ async def _run_collection_query_framing_agent (query : str ) -> CollectionQuery :
119+ result = await collection_query_framing_agent .run (query )
120+ return result .data
121+
122+
123+ @async_cached (agent_cache )
111124async def search_collections (
112125 query : str , catalog_url : str = None
113126) -> CollectionSearchResult | None :
114127 logger .info ("Searching for relevant collections ..." )
115- collection_query = await collection_query_framing_agent . run (query )
116- logger .info (f"Framed collection query: { collection_query .data . query } " )
117- if collection_query .data . is_specific :
128+ collection_query = await _run_collection_query_framing_agent (query )
129+ logger .info (f"Framed collection query: { collection_query .query } " )
130+ if collection_query .is_specific :
118131 collections = await collection_search (
119- collection_query .data . query , catalog_url = catalog_url
132+ collection_query .query , catalog_url = catalog_url
120133 )
121134 return CollectionSearchResult (collections = collections )
122135 else :
@@ -143,10 +156,15 @@ class GeocodingResult:
143156)
144157
145158
159+ @async_cached (geocoding_cache )
160+ async def _run_geocoding_agent (query : str ) -> GeocodingResult :
161+ result = await geocoding_agent .run (query )
162+ return result .data
163+
164+
146165@search_items_agent .tool
147166async def set_spatial_extent (ctx : RunContext [Context ]) -> GeocodingResult :
148- result = await geocoding_agent .run (ctx .deps .query )
149- return result .data
167+ return await _run_geocoding_agent (ctx .deps .query )
150168
151169
152170@dataclass
@@ -170,10 +188,15 @@ def temporal_range_agent_system_prompt():
170188 return f"The current date is { date .today ()} "
171189
172190
191+ @async_cached (agent_cache )
192+ async def _run_temporal_range_agent (query : str ) -> TemporalRangeResult :
193+ result = await temporal_range_agent .run (query )
194+ return result .data
195+
196+
173197@search_items_agent .tool
174198async def set_temporal_range (ctx : RunContext [Context ]) -> TemporalRangeResult :
175- result = await temporal_range_agent .run (ctx .deps .query )
176- return result .data
199+ return await _run_temporal_range_agent (ctx .deps .query )
177200
178201
179202class PropertyRef (BaseModel ):
@@ -255,11 +278,18 @@ class FilterExpr(BaseModel):
255278)
256279
257280
281+ @async_cached (agent_cache )
282+ async def _run_cql2_filter_agent (query : str ) -> FilterExpr | None :
283+ result = await cql2_filter_agent .run (query )
284+ return result .data
285+
286+
258287@search_items_agent .tool
259288async def construct_cql2_filter (ctx : RunContext [Context ]) -> FilterExpr | None :
260- return await cql2_filter_agent . run (ctx .deps .query )
289+ return await _run_cql2_filter_agent (ctx .deps .query )
261290
262291
292+ @async_cached (geocoding_cache )
263293async def get_polygon_from_geodini (location : str ):
264294 geodini_api = f"{ GEODINI_API } /search_complex"
265295 async with aiohttp .ClientSession () as session :
@@ -281,8 +311,8 @@ class ItemSearchResult:
281311async def item_search (ctx : Context ) -> ItemSearchResult :
282312 start_time = time .time ()
283313 # formulate the query to be used for the search
284- results = await search_items_agent . run (
285- f"Find items for the query: { ctx .query } " , deps = ctx
314+ results = await _run_search_items_agent (
315+ query = f"Find items for the query: { ctx .query } " , deps = asdict ( ctx )
286316 )
287317 query_formulation_time = time .time ()
288318 logger .info (
@@ -330,8 +360,8 @@ async def item_search(ctx: Context) -> ItemSearchResult:
330360 params = {
331361 "max_items" : 20 ,
332362 "collections" : collections_to_search ,
333- "datetime" : results .data . datetime ,
334- "filter" : results .data . filter ,
363+ "datetime" : results .datetime ,
364+ "filter" : results .filter ,
335365 }
336366
337367 logger .info (f"Searching with params: { params } " )
@@ -340,12 +370,12 @@ async def item_search(ctx: Context) -> ItemSearchResult:
340370 f"Params formulation time: { params_formulation_time - query_formulation_time } seconds"
341371 )
342372
343- polygon = await get_polygon_from_geodini (results .data . location )
373+ polygon = await get_polygon_from_geodini (results .location )
344374 if polygon :
345- logger .info (f"Found polygon for { results .data . location } " )
375+ logger .info (f"Found polygon for { results .location } " )
346376 params ["intersects" ] = polygon
347377 else :
348- explanation += f"\n \n No polygon found for { results .data . location } . "
378+ explanation += f"\n \n No polygon found for { results .location } . "
349379 return ItemSearchResult (
350380 items = None , search_params = params , aoi = None , explanation = explanation
351381 )
0 commit comments