Skip to content

Commit ec593fd

Browse files
committed
add stac+xarray notebook
1 parent e14e04e commit ec593fd

File tree

2 files changed

+379
-0
lines changed

2 files changed

+379
-0
lines changed

docs/mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ nav:
8080
- Custom Sentinel 2 Tiler: "examples/code/tiler_for_sentinel2.md"
8181
- Add custom algorithms: "examples/code/tiler_with_custom_algorithm.md"
8282
- GDAL WMTS Extension: "examples/code/create_gdal_wmts_extension.md"
83+
- STAC + Xarray: "examples/code/tiler_with_custom_stac+xarray.md"
8384

8485
- Use TiTiler endpoints:
8586
- COG: "examples/notebooks/Working_with_CloudOptimizedGeoTIFF_simple.ipynb"
Lines changed: 378 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
2+
**Goal**: Create a custom STAC Reader supporting both COG and NetCDF/Zarr dataset
3+
4+
**requirements**:
5+
- `titiler.core`
6+
- `titiler.xarray`
7+
- `fsspec`
8+
- `zarr`
9+
- `h5netcdf`
10+
- `aiohttp` (optional)
11+
- `s3fs` (optional)
12+
13+
**links**:
14+
15+
- https://cogeotiff.github.io/rio-tiler/examples/STAC_datacube_support/
16+
17+
18+
#### 1. Custom STACReader
19+
20+
First, we need to create a custom `STACReader` which will support both COG and NetCDF/Zarr dataset. The custom parts will be:
21+
22+
- add `netcdf` and `zarr` as valid asset media types
23+
- introduce a new `md://` prefixed asset form, so users can pass `assets=md://{netcdf asset name}?variable={variable name}` as we do for the `GDAL vrt string connection` support.
24+
25+
```python title="stac.py"
26+
from typing import Set, Type, Tuple, Dict, Optional
27+
28+
import attr
29+
from urllib.parse import urlparse, parse_qsl
30+
from rio_tiler.types import AssetInfo
31+
from rio_tiler.io import BaseReader, Reader
32+
from rio_tiler.io import stac
33+
34+
from titiler.xarray.io import Reader as XarrayReader
35+
36+
valid_types = {
37+
*stac.DEFAULT_VALID_TYPE,
38+
"application/x-netcdf",
39+
"application/vnd+zarr",
40+
}
41+
42+
43+
@attr.s
44+
class STACReader(stac.STACReader):
45+
"""Custom STACReader which adds support for `md://` prefixed assets.
46+
47+
Example:
48+
>>> with STACReader("https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json") as src:
49+
print(src.assets)
50+
print(src._get_asset_info("md://netcdf?variable=dataset"))
51+
52+
['geotiff', 'netcdf']
53+
{'url': 'https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/dataset_2d.nc', 'metadata': {}, 'reader_options': {'variable': 'dataset'}, 'media_type': 'application/x-netcdf'}
54+
55+
"""
56+
include_asset_types: Set[str] = attr.ib(default=valid_types)
57+
58+
def _get_reader(self, asset_info: AssetInfo) -> Tuple[Type[BaseReader], Dict]:
59+
"""Get Asset Reader."""
60+
asset_type = asset_info.get("media_type", None)
61+
if asset_type and asset_type in [
62+
"application/x-netcdf",
63+
"application/vnd+zarr",
64+
"application/x-hdf5",
65+
"application/x-hdf",
66+
]:
67+
return XarrayReader, asset_info.get("reader_options", {})
68+
69+
return Reader, asset_info.get("reader_options", {})
70+
71+
def _parse_md_asset(self, asset: str) -> Tuple[str, Optional[Dict]]:
72+
"""Parse md:// asset string and return both asset name and reader options"""
73+
if asset.startswith("md://") and asset not in self.assets:
74+
parsed = urlparse(asset)
75+
if not parsed.netloc or parsed.netloc not in self.assets:
76+
raise InvalidAssetName(
77+
f"'{parsed.netloc}' is not valid, should be one of {self.assets}"
78+
)
79+
80+
# NOTE: by using `parse_qsl` we assume the
81+
# reader_options are in form of `key=single_value`
82+
# reader_options for XarrayReader are:
83+
# - variable: str
84+
# - group: Optional[str]
85+
# - decode_times: bool = True
86+
# - datetime: Optional[str]
87+
# - drop_dim: Optional[str]
88+
return parsed.netloc, dict(parse_qsl(parsed.query))
89+
90+
return asset, None
91+
92+
def _get_asset_info(self, asset: str) -> AssetInfo:
93+
"""Validate asset names and return asset's info.
94+
95+
Args:
96+
asset (str): STAC asset name.
97+
98+
Returns:
99+
AssetInfo: STAC asset info.
100+
101+
"""
102+
vrt_options = None
103+
reader_options = None
104+
if asset.startswith("vrt://"):
105+
asset, vrt_options = self._parse_vrt_asset(asset)
106+
107+
# not part of the original STACReader
108+
elif asset.startswith("md://"):
109+
asset, reader_options = self._parse_md_asset(asset)
110+
111+
if asset not in self.assets:
112+
raise InvalidAssetName(
113+
f"'{asset}' is not valid, should be one of {self.assets}"
114+
)
115+
116+
asset_info = self.item.assets[asset]
117+
extras = asset_info.extra_fields
118+
119+
info = AssetInfo(
120+
url=asset_info.get_absolute_href() or asset_info.href,
121+
metadata=extras if not vrt_options else None,
122+
reader_options=reader_options or {}
123+
)
124+
125+
if stac.STAC_ALTERNATE_KEY and extras.get("alternate"):
126+
if alternate := extras["alternate"].get(stac.STAC_ALTERNATE_KEY):
127+
info["url"] = alternate["href"]
128+
129+
if asset_info.media_type:
130+
info["media_type"] = asset_info.media_type
131+
132+
# https://github.com/stac-extensions/file
133+
if head := extras.get("file:header_size"):
134+
info["env"] = {"GDAL_INGESTED_BYTES_AT_OPEN": head}
135+
136+
# https://github.com/stac-extensions/raster
137+
if extras.get("raster:bands") and not vrt_options:
138+
bands = extras.get("raster:bands")
139+
stats = [
140+
(b["statistics"]["minimum"], b["statistics"]["maximum"])
141+
for b in bands
142+
if {"minimum", "maximum"}.issubset(b.get("statistics", {}))
143+
]
144+
# check that stats data are all double and make warning if not
145+
if (
146+
stats
147+
and all(isinstance(v, (int, float)) for stat in stats for v in stat)
148+
and len(stats) == len(bands)
149+
):
150+
info["dataset_statistics"] = stats
151+
else:
152+
warnings.warn(
153+
"Some statistics data in STAC are invalid, they will be ignored."
154+
)
155+
156+
if vrt_options:
157+
info["url"] = f"vrt://{info['url']}?{vrt_options}"
158+
159+
return info
160+
```
161+
162+
#### 2. Endpoint Factory
163+
164+
Custom `MultiBaseTilerFactory` which removes some endpoints (`/preview`) and adapt dependencies to work with both COG and Xarray Datasets.
165+
166+
```python title="factory.py"
167+
"""Custom MultiBaseTilerFactory."""
168+
from dataclasses import dataclass
169+
from typing import Type, Union, Optional, List
170+
from typing_extensions import Annotated
171+
from attrs import define, field
172+
from geojson_pydantic.features import Feature, FeatureCollection
173+
from fastapi import Body, Depends, Query
174+
from titiler.core import factory
175+
from titiler.core.dependencies import (
176+
DefaultDependency,
177+
BidxParams,
178+
AssetsParams,
179+
AssetsBidxExprParamsOptional,
180+
CoordCRSParams,
181+
DstCRSParams,
182+
)
183+
from titiler.core.models.responses import MultiBaseStatisticsGeoJSON
184+
from titiler.core.resources.responses import GeoJSONResponse
185+
from rio_tiler.constants import WGS84_CRS
186+
from rio_tiler.io import MultiBaseReader
187+
188+
from stac import STACReader
189+
190+
191+
# Simple Asset dependency (1 asset, no expression)
192+
@dataclass
193+
class SingleAssetsParams(DefaultDependency):
194+
"""Custom Assets parameters which only accept ONE asset and make it required."""
195+
196+
assets: Annotated[
197+
str,
198+
Query(title="Asset names", description="Asset's name."),
199+
]
200+
201+
indexes: Annotated[
202+
Optional[List[int]],
203+
Query(
204+
title="Band indexes",
205+
alias="bidx",
206+
description="Dataset band indexes",
207+
openapi_examples={
208+
"one-band": {"value": [1]},
209+
"multi-bands": {"value": [1, 2, 3]},
210+
},
211+
),
212+
] = None
213+
214+
215+
@define(kw_only=True)
216+
class MultiBaseTilerFactory(factory.MultiBaseTilerFactory):
217+
218+
reader: Type[MultiBaseReader] = STACReader
219+
220+
# Assets/Indexes/Expression dependency
221+
layer_dependency: Type[DefaultDependency] = SingleAssetsParams
222+
223+
# Assets dependency (for /info endpoints)
224+
assets_dependency: Type[DefaultDependency] = AssetsParams
225+
226+
# remove preview endpoints
227+
img_preview_dependency: Type[DefaultDependency] = field(init=False)
228+
add_preview: bool = field(init=False, default=False)
229+
230+
# Overwrite the `/statistics` endpoint to remove `full` dataset statistics (which could be unusable for NetCDF dataset)
231+
def statistics(self): # noqa: C901
232+
"""Register /statistics endpoint."""
233+
234+
@self.router.post(
235+
"/statistics",
236+
response_model=MultiBaseStatisticsGeoJSON,
237+
response_model_exclude_none=True,
238+
response_class=GeoJSONResponse,
239+
responses={
240+
200: {
241+
"content": {"application/geo+json": {}},
242+
"description": "Return dataset's statistics from feature or featureCollection.",
243+
}
244+
},
245+
)
246+
def geojson_statistics(
247+
geojson: Annotated[
248+
Union[FeatureCollection, Feature],
249+
Body(description="GeoJSON Feature or FeatureCollection."),
250+
],
251+
src_path=Depends(self.path_dependency),
252+
reader_params=Depends(self.reader_dependency),
253+
layer_params=Depends(AssetsBidxExprParamsOptional),
254+
dataset_params=Depends(self.dataset_dependency),
255+
coord_crs=Depends(CoordCRSParams),
256+
dst_crs=Depends(DstCRSParams),
257+
post_process=Depends(self.process_dependency),
258+
image_params=Depends(self.img_part_dependency),
259+
stats_params=Depends(self.stats_dependency),
260+
histogram_params=Depends(self.histogram_dependency),
261+
env=Depends(self.environment_dependency),
262+
):
263+
"""Get Statistics from a geojson feature or featureCollection."""
264+
fc = geojson
265+
if isinstance(fc, Feature):
266+
fc = FeatureCollection(type="FeatureCollection", features=[geojson])
267+
268+
with rasterio.Env(**env):
269+
with self.reader(src_path, **reader_params.as_dict()) as src_dst:
270+
# Default to all available assets
271+
if not layer_params.assets and not layer_params.expression:
272+
layer_params.assets = src_dst.assets
273+
274+
for feature in fc:
275+
image = src_dst.feature(
276+
feature.model_dump(exclude_none=True),
277+
shape_crs=coord_crs or WGS84_CRS,
278+
dst_crs=dst_crs,
279+
align_bounds_with_dataset=True,
280+
**layer_params.as_dict(),
281+
**image_params.as_dict(),
282+
**dataset_params.as_dict(),
283+
)
284+
285+
if post_process:
286+
image = post_process(image)
287+
288+
stats = image.statistics(
289+
**stats_params.as_dict(),
290+
hist_options=histogram_params.as_dict(),
291+
)
292+
293+
feature.properties = feature.properties or {}
294+
# NOTE: because we use `src_dst.feature` the statistics will be in form of
295+
# `Dict[str, BandStatistics]` and not `Dict[str, Dict[str, BandStatistics]]`
296+
feature.properties.update({"statistics": stats})
297+
298+
return fc.features[0] if isinstance(geojson, Feature) else fc
299+
```
300+
301+
#### 3. Application
302+
303+
```python title="main.py"
304+
"""FastAPI application."""
305+
306+
from fastapi import FastAPI
307+
308+
from titiler.core.dependencies import DatasetPathParams
309+
from titiler.core.errors import DEFAULT_STATUS_CODES, add_exception_handlers
310+
311+
from factory import MultiBaseTilerFactory
312+
313+
# STAC uses MultiBaseReader so we use MultiBaseTilerFactory to built the default endpoints
314+
stac = MultiBaseTilerFactory(router_prefix="stac")
315+
316+
# Create FastAPI application
317+
app = FastAPI()
318+
app.include_router(stac.router, tags=["STAC"])
319+
add_exception_handlers(app, DEFAULT_STATUS_CODES)
320+
```
321+
322+
```
323+
uvicorn app:app --port 8080 --reload
324+
```
325+
326+
<img width="800" alt="Screenshot 2024-11-07 at 4 42 21 PM" src="https://github.com/user-attachments/assets/2b68500e-c1a7-4461-90bd-67bb492e6057">
327+
328+
329+
```bash
330+
curl http://127.0.0.1:8080/assets\?url\=https%3A%2F%2Fraw.githubusercontent.com%2Fcogeotiff%2Frio-tiler%2Frefs%2Fheads%2Fmain%2Ftests%2Ffixtures%2Fstac_netcdf.json | jq
331+
332+
[
333+
"geotiff",
334+
"netcdf"
335+
]
336+
```
337+
338+
```bash
339+
curl http://127.0.0.1:8080/info?url=https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json&assets=md://netcdf?variable=dataset | jq
340+
{
341+
"md://netcdf?variable=dataset": {
342+
"bounds": [
343+
-170.085,
344+
-80.08,
345+
169.914999999975,
346+
79.91999999999659
347+
],
348+
"crs": "http://www.opengis.net/def/crs/EPSG/0/4326",
349+
"band_metadata": [
350+
[
351+
"b1",
352+
{}
353+
]
354+
],
355+
"band_descriptions": [
356+
[
357+
"b1",
358+
"value"
359+
]
360+
],
361+
"dtype": "float64",
362+
"nodata_type": "Nodata",
363+
"name": "dataset",
364+
"count": 1,
365+
"width": 2000,
366+
"height": 1000,
367+
"attrs": {
368+
"valid_min": 1.0,
369+
"valid_max": 1000.0,
370+
"fill_value": 0
371+
}
372+
}
373+
}
374+
```
375+
376+
```bash
377+
curl http://127.0.0.1:8080/tiles/WebMercatorQuad/1/0/0?url=https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json&assets=md://netcdf?variable=dataset&rescale=0,1000
378+
```

0 commit comments

Comments
 (0)