Skip to content

Commit c3d545a

Browse files
add: reader level option for XarrayReader (#863)
* add: reader level option for XarrayReader * fix
1 parent a007b7b commit c3d545a

File tree

4 files changed

+108
-2
lines changed

4 files changed

+108
-2
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11

22
# Unreleased
33

4+
# 9.0.0a5 (2026-02-13)
5+
6+
* add: reader's level options for XarrayReader
7+
48
# 9.0.0a4 (2026-02-11)
59

610
* fix: type hint for ImageData/PointData methods

rio_tiler/io/xarray.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import os
77
import warnings
8-
from typing import Any, cast
8+
from typing import Any, TypedDict, cast
99

1010
import attr
1111
import numpy
@@ -55,6 +55,12 @@
5555
MAX_ARRAY_SIZE = int(os.environ.get("RIO_TILER_MAX_ARRAY_SIZE", 1_000_000_000)) # 1Gb
5656

5757

58+
class Options(TypedDict, total=False):
59+
"""Reader Options."""
60+
61+
nodata: NoData | None
62+
63+
5864
@attr.s
5965
class XarrayReader(BaseReader):
6066
"""Xarray Reader.
@@ -81,8 +87,14 @@ class XarrayReader(BaseReader):
8187

8288
tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)
8389

90+
options: Options = attr.ib()
91+
8492
_dims: list = attr.ib(init=False, factory=list)
8593

94+
@options.default
95+
def _options_default(self):
96+
return {}
97+
8698
def __attrs_post_init__(self):
8799
"""Set bounds and CRS."""
88100
assert xarray is not None, "xarray must be installed to use XarrayReader"
@@ -168,6 +180,10 @@ def info(self) -> Info:
168180
"""Return xarray.DataArray info."""
169181
metadata = [band.attrs for d in self._dims for band in self.input[d]] or [{}]
170182

183+
nodata_type = "None"
184+
if self.options.get("nodata", self.input.rio.nodata) is not None:
185+
nodata_type = "Nodata"
186+
171187
meta = {
172188
"bounds": self.bounds,
173189
"crs": CRS_to_uri(self.crs) or self.crs.to_wkt(),
@@ -177,7 +193,7 @@ def info(self) -> Info:
177193
for ix, v in enumerate(self.band_descriptions, 1)
178194
],
179195
"dtype": str(self.input.dtype),
180-
"nodata_type": "Nodata" if self.input.rio.nodata is not None else "None",
196+
"nodata_type": nodata_type,
181197
"name": self.input.name,
182198
"count": self.input.rio.count,
183199
"width": self.input.rio.width,
@@ -233,6 +249,8 @@ def statistics(
233249

234250
if nodata is not None:
235251
da = da.rio.write_nodata(nodata)
252+
elif (nodata := self.options.get("nodata")) is not None:
253+
da = da.rio.write_nodata(nodata)
236254

237255
data = da.to_masked_array()
238256
data.mask |= data.data == da.rio.nodata
@@ -353,6 +371,8 @@ def part( # noqa: C901
353371

354372
if nodata is not None:
355373
da = da.rio.write_nodata(nodata)
374+
elif (nodata := self.options.get("nodata")) is not None:
375+
da = da.rio.write_nodata(nodata)
356376

357377
# Forward valid_min/valid_max to the ImageData object
358378
minv, maxv = da.attrs.get("valid_min"), da.attrs.get("valid_max")
@@ -507,6 +527,8 @@ def preview( # noqa: C901
507527

508528
if nodata is not None:
509529
da = da.rio.write_nodata(nodata)
530+
elif (nodata := self.options.get("nodata")) is not None:
531+
da = da.rio.write_nodata(nodata)
510532

511533
if dst_crs and dst_crs != self.crs:
512534
src_width = da.rio.width
@@ -641,6 +663,8 @@ def point(
641663

642664
if nodata is not None:
643665
da = da.rio.write_nodata(nodata)
666+
elif (nodata := self.options.get("nodata")) is not None:
667+
da = da.rio.write_nodata(nodata)
644668

645669
y, x = rowcol(da.rio.transform(), da_lon, da_lat)
646670

tests/test_io_stac.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from rio_tiler.errors import InvalidAssetName, MissingAssets, TileOutsideBounds
2121
from rio_tiler.io import BaseReader, Reader, STACReader, XarrayReader
2222
from rio_tiler.io.stac import DEFAULT_VALID_TYPE
23+
from rio_tiler.io.xarray import Options
2324
from rio_tiler.models import BandStatistics
2425
from rio_tiler.types import AssetInfo
2526

@@ -885,8 +886,14 @@ class NetCDFReader(XarrayReader):
885886
ds: xarray.Dataset = attr.ib(init=False)
886887
input: xarray.DataArray = attr.ib(init=False)
887888

889+
options: Options = attr.ib()
890+
888891
_dims: List = attr.ib(init=False, factory=list)
889892

893+
@options.default
894+
def _options_default(self):
895+
return {}
896+
890897
def __attrs_post_init__(self):
891898
"""Set bounds and CRS."""
892899
self.ds = xarray.open_dataset(self.src_path, decode_coords="all")

tests/test_io_xarray.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,3 +1017,74 @@ def test_max_pixels():
10171017
# Should not raise error when using a small area
10181018
img = dst.tile(0, 0, 5)
10191019
assert img.array.shape == (3, 256, 256)
1020+
1021+
1022+
def test_xarray_reader_nodata_option():
1023+
"""test XarrayReader."""
1024+
# Create a 360/180 dataset that covers the whole world
1025+
arr = numpy.arange(0.0, 360 * 180).reshape(1, 180, 360)
1026+
arr[:, 0:50, 0:50] = 0 # we set the top-left corner to 0
1027+
1028+
data = xarray.DataArray(
1029+
arr,
1030+
dims=("time", "y", "x"),
1031+
coords={
1032+
"x": numpy.arange(-179.5, 180.5, 1),
1033+
"y": numpy.arange(89.5, -90.5, -1),
1034+
"time": [datetime(2022, 1, 1)],
1035+
},
1036+
)
1037+
1038+
data.attrs.update({"valid_min": arr.min(), "valid_max": arr.max()})
1039+
1040+
data.rio.write_crs("epsg:4326", inplace=True)
1041+
assert data.rio.nodata is None
1042+
with pytest.warns(
1043+
UserWarning,
1044+
match="Adjusting dataset latitudes to avoid re-projection overflow",
1045+
):
1046+
with XarrayReader(data, options={"nodata": 0}) as dst:
1047+
info = dst.info()
1048+
assert info.height == 180
1049+
assert info.width == 360
1050+
assert info.count == 1
1051+
1052+
img = dst.tile(0, 0, 1)
1053+
assert not img._mask.all() # not all the mask value are set to 255
1054+
assert img.array.mask[0, 0, 0] # the top left pixel should be masked
1055+
assert not img.array.mask[0, 100, 100] # pixel 100,100 shouldn't be masked
1056+
assert dst.input.rio.nodata is None
1057+
1058+
img = dst.part((-160, -80, 160, 80))
1059+
assert not img._mask.all() # not all the mask value are set to 255
1060+
assert img.array.mask[0, 0, 0] # the top left pixel should be masked
1061+
assert not img.array.mask[0, 100, 100] # pixel 100,100 shouldn't be masked
1062+
1063+
# overwrite the nodata value to 0
1064+
pt = dst.point(-179, 89)
1065+
assert pt.count == 1
1066+
assert not pt._mask[0]
1067+
assert dst.input.rio.nodata is None
1068+
1069+
feat = {
1070+
"type": "Feature",
1071+
"geometry": {
1072+
"type": "Polygon",
1073+
"coordinates": [
1074+
[
1075+
[-180.0, 0],
1076+
[-180.0, 85.0511287798066],
1077+
[0, 85.0511287798066],
1078+
[0, 6.023673383202919e-13],
1079+
[-180.0, 0],
1080+
]
1081+
],
1082+
},
1083+
"properties": {"title": "XYZ tile (0, 0, 1)"},
1084+
}
1085+
1086+
img = dst.feature(feat)
1087+
assert not img._mask.all() # not all the mask value are set to 255
1088+
assert img.array.mask[0, 0, 0] # the top left pixel should be masked
1089+
assert not img.array.mask[0, 50, 100] # pixel 50,100 shouldn't be masked
1090+
assert dst.input.rio.nodata is None

0 commit comments

Comments
 (0)