diff --git a/tests/fixtures/catalog.json b/tests/fixtures/catalog.json index 292a6e6..6d3c3a4 100644 --- a/tests/fixtures/catalog.json +++ b/tests/fixtures/catalog.json @@ -321,6 +321,23 @@ "rel": "self", "type": "application/json", "href": "https://stac.endpoint.io/collections" + }, + { + "rel": "data", + "type": "application/json", + "href": "https://stac.endpoint.io/collections" + }, + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": "https://stac.endpoint.io/aggregate" + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": "https://stac.endpoint.io/aggregations" } ] } diff --git a/tests/test_advanced_pystac_client.py b/tests/test_advanced_pystac_client.py new file mode 100644 index 0000000..d2dc74b --- /dev/null +++ b/tests/test_advanced_pystac_client.py @@ -0,0 +1,119 @@ +"""Test Advanced PySTAC client.""" +import json +import os +from unittest.mock import MagicMock, patch + +import pytest + +from titiler.stacapi.pystac import Client + +catalog_json = os.path.join(os.path.dirname(__file__), "fixtures", "catalog.json") + + +@pytest.fixture +def mock_stac_io(): + """STAC IO mock""" + return MagicMock() + + +@pytest.fixture +def client(mock_stac_io): + """STAC client mock""" + client = Client(id="pystac-client", description="pystac-client") + + with open(catalog_json, "r") as f: + catalog = json.loads(f.read()) + client.open = MagicMock() + client.open.return_value = catalog + client._collections_href = MagicMock() + client._collections_href.return_value = "http://example.com/collections" + + client._stac_io = mock_stac_io + return client + + +def test_get_supported_aggregations(client, mock_stac_io): + """Test supported STAC aggregation methods""" + mock_stac_io.read_json.return_value = { + "aggregations": [{"name": "aggregation1"}, {"name": "aggregation2"}] + } + supported_aggregations = client.get_supported_aggregations() + assert supported_aggregations == ["aggregation1", "aggregation2"] + + +@patch( + "titiler.stacapi.pystac.advanced_client.Client.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation_unsupported(supported_aggregations, client): + """Test handling of unsupported aggregation types""" + collection_id = "sentinel-2-l2a" + aggregation = "unsupported-aggregation" + + with pytest.warns( + UserWarning, match="Aggregation type unsupported-aggregation is not supported" + ): + aggregation_data = client.get_aggregation(collection_id, aggregation) + assert aggregation_data == [] + + +@patch( + "titiler.stacapi.pystac.advanced_client.Client.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation(supported_aggregations, client, mock_stac_io): + """Test handling aggregation response""" + collection_id = "sentinel-2-l2a" + aggregation = "datetime_frequency" + aggregation_params = {"datetime_frequency_interval": "day"} + + mock_stac_io.read_json.return_value = { + "aggregations": [ + { + "name": "datetime_frequency", + "buckets": [ + { + "key": "2023-12-11T00:00:00.000Z", + "data_type": "frequency_distribution", + "frequency": 1, + "to": None, + "from": None, + } + ], + }, + { + "name": "unusable_aggregation", + "buckets": [ + { + "key": "2023-12-11T00:00:00.000Z", + } + ], + }, + ] + } + + aggregation_data = client.get_aggregation( + collection_id, aggregation, aggregation_params + ) + assert aggregation_data[0]["key"] == "2023-12-11T00:00:00.000Z" + assert aggregation_data[0]["data_type"] == "frequency_distribution" + assert aggregation_data[0]["frequency"] == 1 + assert len(aggregation_data) == 1 + + +@patch( + "titiler.stacapi.pystac.advanced_client.Client.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation_no_response(supported_aggregations, client, mock_stac_io): + """Test handling of no aggregation response""" + collection_id = "sentinel-2-l2a" + aggregation = "datetime_frequency" + aggregation_params = {"datetime_frequency_interval": "day"} + + mock_stac_io.read_json.return_value = [] + + aggregation_data = client.get_aggregation( + collection_id, aggregation, aggregation_params + ) + assert aggregation_data == [] diff --git a/titiler/stacapi/factory.py b/titiler/stacapi/factory.py index 854c25f..e97bbc1 100644 --- a/titiler/stacapi/factory.py +++ b/titiler/stacapi/factory.py @@ -19,7 +19,6 @@ from morecantile import tms as morecantile_tms from morecantile.defaults import TileMatrixSets from pydantic import conint -from pystac_client import Client from pystac_client.stac_api_io import StacApiIO from rasterio.transform import xy as rowcol_to_coords from rasterio.warp import transform as transform_points @@ -48,6 +47,7 @@ from titiler.stacapi.backend import STACAPIBackend from titiler.stacapi.dependencies import APIParams, STACApiParams, STACSearchParams from titiler.stacapi.models import FeatureInfo, LayerDict +from titiler.stacapi.pystac import Client from titiler.stacapi.settings import CacheSettings, RetrySettings from titiler.stacapi.utils import _tms_limits @@ -580,6 +580,7 @@ def get_layer_from_collections( # noqa: C901 tilematrixsets = render.pop("tilematrixsets", None) output_format = render.pop("format", None) + aggregation = render.pop("aggregation", None) _ = render.pop("minmax_zoom", None) # Not Used _ = render.pop("title", None) # Not Used @@ -643,6 +644,20 @@ def get_layer_from_collections( # noqa: C901 "values" ] ] + elif aggregation and aggregation["name"] == "datetime_frequency": + datetime_aggregation = catalog.get_aggregation( + collection_id=collection.id, + aggregation="datetime_frequency", + aggregation_params=aggregation["params"], + ) + layer["time"] = [ + python_datetime.datetime.strptime( + t["key"], + "%Y-%m-%dT%H:%M:%S.000Z", + ).strftime("%Y-%m-%d") + for t in datetime_aggregation + if t["frequency"] > 0 + ] elif intervals := temporal_extent.intervals: start_date = intervals[0][0] end_date = ( diff --git a/titiler/stacapi/pystac/__init__.py b/titiler/stacapi/pystac/__init__.py new file mode 100644 index 0000000..3733c66 --- /dev/null +++ b/titiler/stacapi/pystac/__init__.py @@ -0,0 +1,7 @@ +"""titiler.pystac""" + +__all__ = [ + "Client", +] + +from .advanced_client import Client diff --git a/titiler/stacapi/pystac/advanced_client.py b/titiler/stacapi/pystac/advanced_client.py new file mode 100644 index 0000000..f02bd26 --- /dev/null +++ b/titiler/stacapi/pystac/advanced_client.py @@ -0,0 +1,85 @@ +""" +This module provides an advanced client for interacting with STAC (SpatioTemporal Asset Catalog) APIs. + +The `Client` class extends the basic functionality of the `pystac.Client` to include +methods for retrieving and aggregating data from STAC collections. +""" + +import warnings +from typing import Dict, List, Optional +from urllib.parse import urlencode + +import pystac +import pystac_client + + +class Client(pystac_client.Client): + """Client extends the basic functionality of the pystac.Client class.""" + + def get_aggregation( + self, + collection_id: str, + aggregation: str, + aggregation_params: Optional[Dict] = None, + ) -> List[Dict]: + """Perform an aggregation on a STAC collection. + + Args: + collection_id (str): The ID of the collection to aggregate. + aggregation (str): The aggregation type to perform. + aggregation_params (Optional[dict], optional): Additional parameters for the aggregation. Defaults to None. + Returns: + List[str]: The aggregation response. + """ + if aggregation not in self.get_supported_aggregations(): + warnings.warn( + f"Aggregation type {aggregation} is not supported", stacklevel=1 + ) + return [] + + # Construct the URL for aggregation + url = ( + self._collections_href(collection_id) + + f"/aggregate?aggregations={aggregation}" + ) + if aggregation_params: + params = urlencode(aggregation_params) + url += f"&{params}" + + aggregation_response = self._stac_io.read_json(url) + + if not aggregation_response: + return [] + + aggregation_data = [] + for agg in aggregation_response["aggregations"]: + if agg["name"] == aggregation: + aggregation_data = agg["buckets"] + + return aggregation_data + + def get_supported_aggregations(self) -> List[str]: + """Get the supported aggregation types. + + Returns: + List[str]: The supported aggregations. + """ + response = self._stac_io.read_json(self.get_aggregations_link()) + aggregations = response.get("aggregations", []) + return [agg["name"] for agg in aggregations] + + def get_aggregations_link(self) -> Optional[pystac.Link]: + """Returns this client's aggregations link. + + Returns: + Optional[pystac.Link]: The aggregations link, or None if there is not one found. + """ + return next( + ( + link + for link in self.links + if link.rel == "aggregations" + and link.media_type == pystac.MediaType.JSON + ), + None, + )