Skip to content

Commit 51af1fe

Browse files
committed
feat: extended pystac client to support aggregations stac-api extension calls [https://github.com/stac-api-extensions/aggregation]
1 parent 7214394 commit 51af1fe

File tree

5 files changed

+247
-2
lines changed

5 files changed

+247
-2
lines changed

tests/fixtures/catalog.json

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,23 @@
321321
"rel": "self",
322322
"type": "application/json",
323323
"href": "https://stac.endpoint.io/collections"
324+
},
325+
{
326+
"rel": "data",
327+
"type": "application/json",
328+
"href": "https://stac.endpoint.io/collections"
329+
},
330+
{
331+
"rel": "aggregate",
332+
"type": "application/json",
333+
"title": "Aggregate",
334+
"href": "https://stac.endpoint.io/aggregate"
335+
},
336+
{
337+
"rel": "aggregations",
338+
"type": "application/json",
339+
"title": "Aggregations",
340+
"href": "https://stac.endpoint.io/aggregations"
324341
}
325342
]
326343
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Test Advanced PySTAC client."""
2+
import json
3+
import os
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
from titiler.pystac import AdvancedClient
9+
10+
catalog_json = os.path.join(os.path.dirname(__file__), "fixtures", "catalog.json")
11+
12+
13+
@pytest.fixture
14+
def mock_stac_io():
15+
"""STAC IO mock"""
16+
return MagicMock()
17+
18+
19+
@pytest.fixture
20+
def client(mock_stac_io):
21+
"""STAC client mock"""
22+
client = AdvancedClient(id="pystac-client", description="pystac-client")
23+
24+
with open(catalog_json, "r") as f:
25+
catalog = json.loads(f.read())
26+
client.open = MagicMock()
27+
client.open.return_value = catalog
28+
client._collections_href = MagicMock()
29+
client._collections_href.return_value = "http://example.com/collections"
30+
31+
client._stac_io = mock_stac_io
32+
return client
33+
34+
35+
def test_get_supported_aggregations(client, mock_stac_io):
36+
"""Test supported STAC aggregation methods"""
37+
mock_stac_io.read_json.return_value = {
38+
"aggregations": [{"name": "aggregation1"}, {"name": "aggregation2"}]
39+
}
40+
supported_aggregations = client.get_supported_aggregations()
41+
assert supported_aggregations == ["aggregation1", "aggregation2"]
42+
43+
44+
@patch(
45+
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
46+
return_value=["datetime_frequency"],
47+
)
48+
def test_get_aggregation_unsupported(supported_aggregations, client):
49+
"""Test handling of unsupported aggregation types"""
50+
collection_id = "sentinel-2-l2a"
51+
aggregation = "unsupported-aggregation"
52+
53+
with pytest.warns(
54+
UserWarning, match="Aggregation type unsupported-aggregation is not supported"
55+
):
56+
aggregation_data = client.get_aggregation(collection_id, aggregation)
57+
assert aggregation_data == []
58+
59+
60+
@patch(
61+
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
62+
return_value=["datetime_frequency"],
63+
)
64+
def test_get_aggregation(supported_aggregations, client, mock_stac_io):
65+
"""Test handling aggregation response"""
66+
collection_id = "sentinel-2-l2a"
67+
aggregation = "datetime_frequency"
68+
aggregation_params = {"datetime_frequency_interval": "day"}
69+
70+
mock_stac_io.read_json.return_value = {
71+
"aggregations": [
72+
{
73+
"name": "datetime_frequency",
74+
"buckets": [
75+
{
76+
"key": "2023-12-11T00:00:00.000Z",
77+
"data_type": "frequency_distribution",
78+
"frequency": 1,
79+
"to": None,
80+
"from": None,
81+
}
82+
],
83+
},
84+
{
85+
"name": "unusable_aggregation",
86+
"buckets": [
87+
{
88+
"key": "2023-12-11T00:00:00.000Z",
89+
}
90+
],
91+
},
92+
]
93+
}
94+
95+
aggregation_data = client.get_aggregation(
96+
collection_id, aggregation, aggregation_params
97+
)
98+
assert aggregation_data[0]["key"] == "2023-12-11T00:00:00.000Z"
99+
assert aggregation_data[0]["data_type"] == "frequency_distribution"
100+
assert aggregation_data[0]["frequency"] == 1
101+
assert len(aggregation_data) == 1
102+
103+
104+
@patch(
105+
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
106+
return_value=["datetime_frequency"],
107+
)
108+
def test_get_aggregation_no_response(supported_aggregations, client, mock_stac_io):
109+
"""Test handling of no aggregation response"""
110+
collection_id = "sentinel-2-l2a"
111+
aggregation = "datetime_frequency"
112+
aggregation_params = {"datetime_frequency_interval": "day"}
113+
114+
mock_stac_io.read_json.return_value = []
115+
116+
aggregation_data = client.get_aggregation(
117+
collection_id, aggregation, aggregation_params
118+
)
119+
assert aggregation_data == []

titiler/pystac/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""titiler.pystac"""
2+
3+
__all__ = [
4+
"AdvancedClient",
5+
]
6+
7+
from titiler.pystac.advanced_client import AdvancedClient

titiler/pystac/advanced_client.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
This module provides an advanced client for interacting with STAC (SpatioTemporal Asset Catalog) APIs.
3+
4+
The `AdvancedClient` class extends the basic functionality of the `pystac.Client` to include
5+
methods for retrieving and aggregating data from STAC collections.
6+
"""
7+
8+
import warnings
9+
from typing import Optional
10+
from urllib.parse import urlencode
11+
12+
import pystac
13+
from pystac_client import Client
14+
15+
16+
class AdvancedClient(Client):
17+
"""AdvancedClient extends the basic functionality of the pystac.Client class."""
18+
19+
def get_aggregation(
20+
self,
21+
collection_id: str,
22+
aggregation: str,
23+
aggregation_params: Optional[dict] = None,
24+
) -> list[dict]:
25+
"""Perform an aggregation on a STAC collection.
26+
27+
Args:
28+
collection_id (str): The ID of the collection to aggregate.
29+
aggregation (str): The aggregation type to perform.
30+
aggregation_params (Optional[dict], optional): Additional parameters for the aggregation. Defaults to None.
31+
Returns:
32+
List[str]: The aggregation response.
33+
"""
34+
assert self._stac_io is not None
35+
36+
if aggregation not in self.get_supported_aggregations():
37+
warnings.warn(
38+
f"Aggregation type {aggregation} is not supported", stacklevel=1
39+
)
40+
return []
41+
42+
# Construct the URL for aggregation
43+
url = (
44+
self._collections_href(collection_id)
45+
+ f"/aggregate?aggregations={aggregation}"
46+
)
47+
if aggregation_params:
48+
params = urlencode(aggregation_params)
49+
url += f"&{params}"
50+
51+
aggregation_response = self._stac_io.read_json(url)
52+
53+
if not aggregation_response:
54+
return []
55+
56+
aggregation_data = []
57+
for agg in aggregation_response["aggregations"]:
58+
if agg["name"] == aggregation:
59+
aggregation_data = agg["buckets"]
60+
61+
return aggregation_data
62+
63+
def get_supported_aggregations(self) -> list[str]:
64+
"""Get the supported aggregation types.
65+
66+
Returns:
67+
List[str]: The supported aggregations.
68+
"""
69+
response = self._stac_io.read_json(self.get_aggregations_link())
70+
aggregations = response.get("aggregations", [])
71+
return [agg["name"] for agg in aggregations]
72+
73+
def get_aggregations_link(self) -> Optional[pystac.Link]:
74+
"""Returns this client's aggregations link.
75+
76+
Returns:
77+
Optional[pystac.Link]: The aggregations link, or None if there is not one found.
78+
"""
79+
return next(
80+
(
81+
link
82+
for link in self.links
83+
if link.rel == "aggregations"
84+
and link.media_type == pystac.MediaType.JSON
85+
),
86+
None,
87+
)

titiler/stacapi/factory.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from morecantile import tms as morecantile_tms
2020
from morecantile.defaults import TileMatrixSets
2121
from pydantic import conint
22-
from pystac_client import Client
2322
from pystac_client.stac_api_io import StacApiIO
2423
from rasterio.transform import xy as rowcol_to_coords
2524
from rasterio.warp import transform as transform_points
@@ -45,6 +44,7 @@
4544
from titiler.core.resources.responses import GeoJSONResponse, XMLResponse
4645
from titiler.core.utils import render_image
4746
from titiler.mosaic.factory import PixelSelectionParams
47+
from titiler.pystac import AdvancedClient
4848
from titiler.stacapi.backend import STACAPIBackend
4949
from titiler.stacapi.dependencies import APIParams, STACApiParams, STACSearchParams
5050
from titiler.stacapi.models import FeatureInfo, LayerDict
@@ -568,7 +568,7 @@ def get_layer_from_collections( # noqa: C901
568568
),
569569
headers=headers,
570570
)
571-
catalog = Client.open(url, stac_io=stac_api_io)
571+
catalog = AdvancedClient.open(url, stac_io=stac_api_io)
572572

573573
layers: Dict[str, LayerDict] = {}
574574
for collection in catalog.get_collections():
@@ -580,6 +580,7 @@ def get_layer_from_collections( # noqa: C901
580580

581581
tilematrixsets = render.pop("tilematrixsets", None)
582582
output_format = render.pop("format", None)
583+
aggregation = render.pop("aggregation", None)
583584

584585
_ = render.pop("minmax_zoom", None) # Not Used
585586
_ = render.pop("title", None) # Not Used
@@ -643,6 +644,20 @@ def get_layer_from_collections( # noqa: C901
643644
"values"
644645
]
645646
]
647+
elif aggregation and aggregation["name"] == "datetime_frequency":
648+
datetime_aggregation = catalog.get_aggregation(
649+
collection_id=collection.id,
650+
aggregation="datetime_frequency",
651+
aggregation_params=aggregation["params"],
652+
)
653+
layer["time"] = [
654+
python_datetime.datetime.strptime(
655+
t["key"],
656+
"%Y-%m-%dT%H:%M:%S.000Z",
657+
).strftime("%Y-%m-%d")
658+
for t in datetime_aggregation
659+
if t["frequency"] > 0
660+
]
646661
elif intervals := temporal_extent.intervals:
647662
start_date = intervals[0][0]
648663
end_date = (

0 commit comments

Comments
 (0)