Skip to content

Commit a46c4ad

Browse files
authored
Support dynamically computing embeddings for a catalog (#4)
Users can pass a STAC catalog URL with their search request. The tool fetches the collections from that catalog, computes and caches their embeddings, and runs the search against that catalog. This allows a stac-semantic-search instance to search any public STAC catalog, not just a single fixed one.
1 parent 05fe6d2 commit a46c4ad

File tree

9 files changed

+323
-116
lines changed

9 files changed

+323
-116
lines changed

.env.example

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ SMALL_MODEL_NAME="openai:gpt-4.1-mini"
88

99
GEODINI_API="https://geodini.k8s.labs.ds.io"
1010

11-
STAC_CATALOG_NAME="planetarycomputer"
1211
STAC_CATALOG_URL="https://planetarycomputer.microsoft.com/api/stac/v1"

frontend/streamlit_app.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,57 @@
3434
"""
3535
)
3636

37-
# Create input field for the query
38-
query = st.text_input(
39-
"Enter your query",
40-
placeholder="Find imagery over Paris from 2017",
41-
help="Describe what kind of satellite imagery you're looking for",
42-
)
37+
# Create two columns for query and catalog URL
38+
col1, col2 = st.columns([3, 1])
4339

44-
# Add a search button
45-
search_button = st.button("Search")
40+
with col1:
41+
# Create input field for the query
42+
query = st.text_input(
43+
"Enter your query",
44+
placeholder="Find imagery over Paris from 2017",
45+
help="Describe what kind of satellite imagery you're looking for",
46+
)
47+
# Add a search button
48+
search_button = st.button("Search")
4649

50+
with col2:
51+
# Define catalog options
52+
catalog_options = {
53+
"Planetary Computer": "https://planetarycomputer.microsoft.com/api/stac/v1",
54+
"VEDA": "https://openveda.cloud/api/stac",
55+
"E84 Earth Search": "https://earth-search.aws.element84.com/v1",
56+
"DevSeed EOAPI.dev": "https://stac.eoapi.dev",
57+
"Custom URL": "custom",
58+
}
4759

48-
# Function to run the search asynchronously
49-
async def run_search(query):
50-
response = requests.post(
51-
f"{API_URL}/items/search", json={"query": query, "limit": 10}
60+
# Create dropdown for catalog selection
61+
selected_catalog = st.selectbox(
62+
"Select STAC Catalog",
63+
options=list(catalog_options.keys()),
64+
index=0, # Default to Planetary Computer
65+
help="Choose a predefined STAC catalog or select 'Custom URL' to enter your own.",
5266
)
67+
68+
# Handle custom URL input
69+
if selected_catalog == "Custom URL":
70+
catalog_url = st.text_input(
71+
"Enter Custom Catalog URL",
72+
placeholder="https://your-catalog.com/stac/v1",
73+
help="Enter the URL of your custom STAC catalog.",
74+
)
75+
else:
76+
catalog_url = catalog_options[selected_catalog]
77+
# Show the selected URL as read-only info
78+
st.info(f"Using: {catalog_url}")
79+
80+
81+
# Function to run the search asynchronously
82+
async def run_search(query, catalog_url=None):
83+
payload = {"query": query, "limit": 10}
84+
if catalog_url:
85+
payload["catalog_url"] = catalog_url.strip()
86+
87+
response = requests.post(f"{API_URL}/items/search", json=payload)
5388
return response.json()["results"]
5489

5590

@@ -60,7 +95,7 @@ async def run_search(query):
6095
# Run the async search
6196
loop = asyncio.new_event_loop()
6297
asyncio.set_event_loop(loop)
63-
results = loop.run_until_complete(run_search(query))
98+
results = loop.run_until_complete(run_search(query, catalog_url))
6499
items = results["items"]
65100
aoi = results["aoi"]
66101
explanation = results["explanation"]
@@ -212,10 +247,19 @@ async def run_search(query):
212247
"""
213248
Search for satellite imagery using natural language.
214249
215-
**Examples queries:**
250+
**Available STAC Catalogs:**
251+
- **Planetary Computer**: Microsoft's global dataset catalog
252+
- **VEDA**: NASA's Earth science data catalog
253+
- **E84 Earth Search**: Element 84's STAC catalog for Earth observation data on AWS Open Data
254+
- **DevSeed EOAPI.dev**: DevSeed's example STAC catalog
255+
- **Custom URL**: Enter any STAC-compliant catalog URL
256+
257+
The system will automatically index new catalogs on first use.
258+
259+
**Example queries:**
216260
- imagery of Paris from 2017
217261
- Cloud-free satellite data of Georgia the country from 2022
218-
- relatively cloud-free images in 2024 that have RGB visual bands over Longmont, Colorado that can be downloaded via HTTP
262+
- relatively cloud-free images in 2024 over Longmont, Colorado
219263
- images in 2024 over Odisha with cloud cover between 50 to 60%
220264
- NAIP imagery over the state of Washington
221265
- Burn scar imagery of from 2024 over the state of California

helm-chart/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ The init container uses the same STAC catalog configuration as the API:
137137
api:
138138
env:
139139
STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1"
140-
STAC_CATALOG_NAME: "planetarycomputer"
141140
142141
initContainer:
143142
enabled: true # Set to false to disable data pre-loading

helm-chart/values.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ api:
4141
PYTHONUNBUFFERED: "1"
4242
HF_HOME: "/app/data/.cache/huggingface"
4343
GEODINI_API: "https://geodini.k8s.labs.ds.io"
44-
STAC_CATALOG_NAME: "planetarycomputer"
4544
STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1"
4645
DEFAULT_TARGET_COLLECTIONS: "['landsat-8-c2-l2', 'sentinel-2-l2a']"
4746

stac_search/agents/collections_search.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from pprint import pformat
1010
from typing import List, Dict, Any
1111

12-
import chromadb
1312
from pydantic_ai import Agent
14-
from sentence_transformers import SentenceTransformer
13+
from stac_search.catalog_manager import CatalogManager
1514

1615

1716
logger = logging.getLogger(__name__)
@@ -20,7 +19,6 @@
2019
MODEL_NAME = "all-MiniLM-L6-v2"
2120
DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb")
2221

23-
STAC_CATALOG_NAME = os.getenv("STAC_CATALOG_NAME", "planetarycomputer")
2422
STAC_COLLECTIONS_URL = os.getenv(
2523
"STAC_COLLECTIONS_URL", "https://planetarycomputer.microsoft.com/api/stac/v1"
2624
)
@@ -65,7 +63,7 @@ async def collection_search(
6563
top_k: int = 5,
6664
model_name: str = MODEL_NAME,
6765
data_path: str = DATA_PATH,
68-
stac_catalog_name: str = STAC_CATALOG_NAME,
66+
catalog_url: str = None,
6967
) -> List[CollectionWithExplanation]:
7068
"""
7169
Search for collections and rerank results with explanations
@@ -75,25 +73,31 @@ async def collection_search(
7573
top_k: Maximum number of results to return
7674
model_name: Name of the sentence transformer model to use
7775
data_path: Path to the vector database
78-
stac_catalog_name: Name of the STAC catalog
79-
stac_collections_url: URL of the STAC collections API
76+
catalog_url: URL of the STAC catalog
8077
8178
Returns:
8279
Ranked results with relevance explanations
8380
"""
8481
start_time = time.time()
8582

86-
# Initialize model and database connections
87-
model = SentenceTransformer(model_name)
83+
# Initialize catalog manager
84+
catalog_manager = CatalogManager(data_path=data_path, model_name=model_name)
85+
86+
# If catalog_url is provided, ensure it's loaded
87+
if catalog_url:
88+
load_result = await catalog_manager.load_catalog(catalog_url)
89+
if not load_result["success"]:
90+
logger.error(f"Failed to load catalog: {load_result['error']}")
91+
raise ValueError(f"Failed to load catalog: {load_result['error']}")
92+
93+
# Get the appropriate collection
94+
collection = catalog_manager.get_catalog_collection(catalog_url)
95+
8896
load_model_time = time.time()
8997
logger.info(f"Model loading time: {load_model_time - start_time:.4f} seconds")
9098

91-
client = chromadb.PersistentClient(path=data_path)
92-
collection_name = f"{stac_catalog_name}_collections"
93-
collection = client.get_collection(name=collection_name)
94-
9599
# Generate query embedding
96-
query_embedding = model.encode([query])
100+
query_embedding = catalog_manager.model.encode([query])
97101

98102
# Search vector database
99103
results = collection.query(

stac_search/agents/items_search.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
@dataclass
3434
class Context:
3535
query: str
36+
catalog_url: str | None = None
3637
location: str | None = None
3738
top_k: int = 5
3839
return_search_params_only: bool = False
@@ -105,12 +106,16 @@ class CollectionSearchResult:
105106
collections: List[CollectionWithExplanation]
106107

107108

108-
async def search_collections(query: str) -> CollectionSearchResult | None:
109+
async def search_collections(
110+
query: str, catalog_url: str = None
111+
) -> CollectionSearchResult | None:
109112
logger.info("Searching for relevant collections ...")
110113
collection_query = await collection_query_framing_agent.run(query)
111114
logger.info(f"Framed collection query: {collection_query.data.query}")
112115
if collection_query.data.is_specific:
113-
collections = await collection_search(collection_query.data.query)
116+
collections = await collection_search(
117+
collection_query.data.query, catalog_url=catalog_url
118+
)
114119
return CollectionSearchResult(collections=collections)
115120
else:
116121
return None
@@ -278,24 +283,42 @@ async def item_search(ctx: Context) -> ItemSearchResult:
278283
results = await search_items_agent.run(
279284
f"Find items for the query: {ctx.query}", deps=ctx
280285
)
286+
catalog_url_to_use = ctx.catalog_url or STAC_CATALOG_URL
281287

282288
# determine the collections to search
283-
target_collections = await search_collections(ctx.query) or []
289+
target_collections = await search_collections(ctx.query, catalog_url_to_use) or []
284290
logger.info(f"Target collections: {pformat(target_collections)}")
285-
default_target_collections = DEFAULT_TARGET_COLLECTIONS
291+
292+
if not target_collections:
293+
# If no specific collections were found, use the default target collections
294+
default_target_collections = DEFAULT_TARGET_COLLECTIONS
295+
# check that default_target_collections exist in the catalog
296+
all_collection_ids = [
297+
collection.id
298+
for collection in Client.open(catalog_url_to_use).get_collections()
299+
]
300+
default_target_collections = [
301+
collection_id
302+
for collection_id in default_target_collections
303+
if collection_id in all_collection_ids
304+
]
305+
286306
if target_collections:
287307
explanation = "Considering the following collections:"
288308
for result in target_collections.collections:
289309
explanation += f"\n- {result.collection_id}: {result.explanation}"
290310
collections_to_search = [
291311
collection.collection_id for collection in target_collections.collections
292312
]
293-
else:
313+
elif default_target_collections:
294314
explanation = f"Including the following common collections in the search: {', '.join(default_target_collections)}\n"
295315
collections_to_search = default_target_collections
316+
else:
317+
explanation = "Searching all collections in the catalog."
318+
collections_to_search = all_collection_ids
296319

297320
# Actually perform the search
298-
client = Client.open(STAC_CATALOG_URL)
321+
client = Client.open(catalog_url_to_use)
299322
params = {
300323
"max_items": 20,
301324
"collections": collections_to_search,
@@ -310,11 +333,9 @@ async def item_search(ctx: Context) -> ItemSearchResult:
310333
logger.info(f"Found polygon for {results.data.location}")
311334
params["intersects"] = polygon
312335
else:
336+
explanation += f"\n\n No polygon found for {results.data.location}. "
313337
return ItemSearchResult(
314-
items=None,
315-
search_params=params,
316-
aoi=None,
317-
explanation=f"No polygon found for {results.data.location}",
338+
items=None, search_params=params, aoi=None, explanation=explanation
318339
)
319340

320341
if ctx.return_search_params_only:

stac_search/api.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
FastAPI server for STAC Natural Query
33
"""
44

5-
from fastapi import FastAPI
5+
from fastapi import FastAPI, HTTPException
66
from fastapi.middleware.cors import CORSMiddleware
77
from pydantic import BaseModel
8+
from typing import Optional
89
import uvicorn
910

1011
from stac_search.agents.collections_search import collection_search
@@ -30,29 +31,41 @@
3031
# Define request model
3132
class QueryRequest(BaseModel):
3233
query: str
34+
catalog_url: Optional[str] = None
3335

3436

3537
class STACItemsRequest(BaseModel):
3638
query: str
39+
catalog_url: Optional[str] = None
3740
return_search_params_only: bool = False
3841

3942

4043
# Define search endpoint
4144
@app.post("/search")
4245
async def search(request: QueryRequest):
4346
"""Search for STAC collections using natural language"""
44-
results = collection_search(request.query)
45-
return {"results": results}
47+
try:
48+
results = await collection_search(
49+
request.query, catalog_url=request.catalog_url
50+
)
51+
return {"results": results}
52+
except Exception as e:
53+
raise HTTPException(status_code=500, detail=str(e))
4654

4755

4856
@app.post("/items/search")
4957
async def search_items(request: STACItemsRequest):
5058
"""Search for STAC items using natural language"""
51-
ctx = ItemSearchContext(
52-
query=request.query, return_search_params_only=request.return_search_params_only
53-
)
54-
results = await item_search(ctx)
55-
return {"results": results}
59+
try:
60+
ctx = ItemSearchContext(
61+
query=request.query,
62+
catalog_url=request.catalog_url,
63+
return_search_params_only=request.return_search_params_only,
64+
)
65+
results = await item_search(ctx)
66+
return {"results": results}
67+
except Exception as e:
68+
raise HTTPException(status_code=500, detail=str(e))
5669

5770

5871
def start_server(host: str = "0.0.0.0", port: int = 8000):

0 commit comments

Comments
 (0)