Skip to content

Commit 6ddd055

Browse files
committed
Make potentially blocking calls async
1 parent 138d49e commit 6ddd055

File tree

4 files changed

+53
-47
lines changed

4 files changed

+53
-47
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"openai",
1919
"pydantic-ai",
2020
"shapely",
21+
"aiohttp",
2122
]
2223

2324
[tool.setuptools]

stac_search/agents/collections_search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Search module for STAC Natural Query
33
"""
44

5+
import asyncio
56
import logging
67
import os
78
import time
@@ -97,10 +98,11 @@ async def collection_search(
9798
logger.info(f"Model loading time: {load_model_time - start_time:.4f} seconds")
9899

99100
# Generate query embedding
100-
query_embedding = catalog_manager.model.encode([query])
101+
query_embedding = await asyncio.to_thread(catalog_manager.model.encode, [query])
101102

102103
# Search vector database
103-
results = collection.query(
104+
results = await asyncio.to_thread(
105+
collection.query,
104106
query_embeddings=query_embedding.tolist(),
105107
n_results=top_k * 2, # Get more results initially for better reranking
106108
)
@@ -131,6 +133,4 @@ async def main():
131133

132134

133135
if __name__ == "__main__":
134-
import asyncio
135-
136136
asyncio.run(main())

stac_search/agents/items_search.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from dataclasses import dataclass
66
from pprint import pformat
77
import time
8+
import asyncio
89
from typing import List, Dict, Any, Union
910

10-
import requests
11+
import aiohttp
1112
from pydantic_ai import Agent, RunContext
1213
from pystac_client import Client
1314
from pydantic import BaseModel, ConfigDict
@@ -259,15 +260,13 @@ async def construct_cql2_filter(ctx: RunContext[Context]) -> FilterExpr | None:
259260
return await cql2_filter_agent.run(ctx.deps.query)
260261

261262

262-
def get_polygon_from_geodini(location: str):
263+
async def get_polygon_from_geodini(location: str):
263264
geodini_api = f"{GEODINI_API}/search_complex"
264-
response = requests.get(
265-
geodini_api,
266-
params={"query": location},
267-
)
268-
result = response.json().get("result", None)
269-
if result:
270-
return result.get("geometry", None)
265+
async with aiohttp.ClientSession() as session:
266+
async with session.get(geodini_api, params={"query": location}) as response:
267+
result = (await response.json()).get("result", None)
268+
if result:
269+
return result.get("geometry", None)
271270
return None
272271

273272

@@ -300,10 +299,12 @@ async def item_search(ctx: Context) -> ItemSearchResult:
300299
# If no specific collections were found, use the default target collections
301300
default_target_collections = DEFAULT_TARGET_COLLECTIONS
302301
# check that default_target_collections exist in the catalog
303-
all_collection_ids = [
304-
collection.id
305-
for collection in Client.open(catalog_url_to_use).get_collections()
306-
]
302+
all_collection_ids = await asyncio.to_thread(
303+
lambda: [
304+
collection.id
305+
for collection in Client.open(catalog_url_to_use).get_collections()
306+
]
307+
)
307308
default_target_collections = [
308309
collection_id
309310
for collection_id in default_target_collections
@@ -339,7 +340,7 @@ async def item_search(ctx: Context) -> ItemSearchResult:
339340
f"Params formulation time: {params_formulation_time - query_formulation_time} seconds"
340341
)
341342

342-
polygon = get_polygon_from_geodini(results.data.location)
343+
polygon = await get_polygon_from_geodini(results.data.location)
343344
if polygon:
344345
logger.info(f"Found polygon for {results.data.location}")
345346
params["intersects"] = polygon
@@ -349,25 +350,21 @@ async def item_search(ctx: Context) -> ItemSearchResult:
349350
items=None, search_params=params, aoi=None, explanation=explanation
350351
)
351352
geocoding_time = time.time()
352-
logger.info(
353-
f"Geocoding time: {geocoding_time - params_formulation_time} seconds"
354-
)
353+
logger.info(f"Geocoding time: {geocoding_time - params_formulation_time} seconds")
355354

356355
if ctx.return_search_params_only:
357356
logger.info("Returning STAC query parameters only")
358357
total_time = time.time() - start_time
359-
logger.info(
360-
f"Total time: {total_time} seconds"
361-
)
358+
logger.info(f"Total time: {total_time} seconds")
362359
return ItemSearchResult(
363360
search_params=params, aoi=polygon, explanation=explanation
364361
)
365362

366-
items = list(client.search(**params).items_as_dicts())
367-
search_time = time.time()
368-
logger.info(
369-
f"Search time: {search_time - geocoding_time} seconds"
363+
items = await asyncio.to_thread(
364+
lambda: list(client.search(**params).items_as_dicts())
370365
)
366+
search_time = time.time()
367+
logger.info(f"Search time: {search_time - geocoding_time} seconds")
371368
total_time = time.time() - start_time
372369
logger.info(f"Total time: {total_time} seconds")
373370
return ItemSearchResult(
@@ -382,6 +379,4 @@ async def main():
382379

383380

384381
if __name__ == "__main__":
385-
import asyncio
386-
387382
asyncio.run(main())

stac_search/catalog_manager.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Catalog Manager for STAC Natural Query - handles dynamic catalog loading and management
33
"""
44

5+
import asyncio
56
import hashlib
67
import logging
78
import os
@@ -52,38 +53,46 @@ def catalog_exists(self, catalog_url: str) -> bool:
5253
logger.error(f"Error checking catalog existence: {e}")
5354
return False
5455

55-
def validate_catalog_url(self, catalog_url: str) -> bool:
56+
async def validate_catalog_url(self, catalog_url: str) -> bool:
5657
"""Validate that the catalog URL is accessible and is a valid STAC catalog"""
5758
try:
58-
stac_client = Client.open(catalog_url)
59-
# Try to get at least one collection to verify it's a valid catalog
60-
collections = list(stac_client.collection_search().collections())
61-
return len(collections) > 0
59+
60+
def _validate():
61+
stac_client = Client.open(catalog_url)
62+
# Try to get at least one collection to verify it's a valid catalog
63+
collections = list(stac_client.collection_search().collections())
64+
return len(collections) > 0
65+
66+
return await asyncio.to_thread(_validate)
6267
except Exception as e:
6368
logger.error(f"Invalid catalog URL {catalog_url}: {e}")
6469
return False
6570

66-
def fetch_collections(self, stac_client: Client) -> list:
71+
async def fetch_collections(self, stac_client: Client) -> list:
6772
"""Fetch STAC collections using pystac-client"""
6873
try:
69-
collections = stac_client.collection_search().collections()
70-
return list(collections)
74+
75+
def _fetch():
76+
collections = stac_client.collection_search().collections()
77+
return list(collections)
78+
79+
return await asyncio.to_thread(_fetch)
7180
except Exception as e:
7281
logger.error(f"Error fetching collections: {e}")
7382
return []
7483

75-
def generate_embeddings(self, collections: list) -> list:
84+
async def generate_embeddings(self, collections: list) -> list:
7685
"""Generate embeddings for each collection (title + description)"""
7786
texts = []
7887
for collection in collections:
7988
title = getattr(collection, "title", "") or ""
8089
description = getattr(collection, "description", "") or ""
8190
texts.append(f"{title} {description}")
8291

83-
embeddings = self.model.encode(texts)
92+
embeddings = await asyncio.to_thread(self.model.encode, texts)
8493
return embeddings
8594

86-
def store_in_vector_db(self, collections: list, chroma_collection) -> None:
95+
async def store_in_vector_db(self, collections: list, chroma_collection) -> None:
8796
"""Store embeddings in ChromaDB"""
8897
if not collections:
8998
logger.warning("No collections to store")
@@ -98,9 +107,10 @@ def store_in_vector_db(self, collections: list, chroma_collection) -> None:
98107
}
99108
metadatas.append(metadata)
100109

101-
embeddings = self.generate_embeddings(collections)
110+
embeddings = await self.generate_embeddings(collections)
102111

103-
chroma_collection.add(
112+
await asyncio.to_thread(
113+
chroma_collection.add,
104114
ids=[str(i) for i in range(len(collections))],
105115
embeddings=embeddings,
106116
metadatas=metadatas,
@@ -110,7 +120,7 @@ async def load_catalog(self, catalog_url: str) -> Dict[str, Any]:
110120
"""Load and index a catalog if it doesn't exist"""
111121
try:
112122
# Validate catalog URL first
113-
if not self.validate_catalog_url(catalog_url):
123+
if not await self.validate_catalog_url(catalog_url):
114124
return {
115125
"success": False,
116126
"error": f"Invalid or inaccessible catalog URL: {catalog_url}",
@@ -127,8 +137,8 @@ async def load_catalog(self, catalog_url: str) -> Dict[str, Any]:
127137

128138
# Load the catalog
129139
logger.info(f"Loading catalog from {catalog_url}")
130-
stac_client = Client.open(catalog_url)
131-
collections = self.fetch_collections(stac_client)
140+
stac_client = await asyncio.to_thread(Client.open, catalog_url)
141+
collections = await self.fetch_collections(stac_client)
132142

133143
if not collections:
134144
return {
@@ -143,7 +153,7 @@ async def load_catalog(self, catalog_url: str) -> Dict[str, Any]:
143153
)
144154

145155
# Store in vector database
146-
self.store_in_vector_db(collections, chroma_collection)
156+
await self.store_in_vector_db(collections, chroma_collection)
147157

148158
logger.info(
149159
f"Successfully indexed {len(collections)} collections from {catalog_url}"

0 commit comments

Comments
 (0)