Skip to content

Commit 736f58e

Browse files
authored
feat: Support loading transparency masks (#36)
* feat: Support loading transparency masks * improve docstring
1 parent 98ae83b commit 736f58e

File tree

7 files changed

+112
-35
lines changed

7 files changed

+112
-35
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
keywords = ["geotiff", "tiff", "async", "cog", "raster", "gis"]
2323
dependencies = [
2424
"affine>=2.4.0",
25-
"async-tiff>=0.4.0",
25+
"async-tiff>=0.5.0-beta.1",
2626
"numpy>=2.0",
2727
"pyproj>=3.3.0",
2828
]

src/async_geotiff/_array.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@ class Array(TransformMixin):
2222
data: NDArray
2323
"""The array data with shape (bands, height, width)."""
2424

25-
mask: NDArray | None
26-
"""The mask array with shape (height, width), if any."""
25+
mask: NDArray[np.bool_] | None
26+
"""The mask array with shape (height, width), if any.
27+
28+
Values of True indicate valid data; False indicates no data.
29+
"""
2730

2831
width: int
2932
"""The width of the array in pixels."""
@@ -66,8 +69,15 @@ def _create(
6669
An Array with data in (bands, height, width) order.
6770
6871
"""
69-
data_arr = np.asarray(data)
70-
mask_arr = np.asarray(mask) if mask is not None else None
72+
data_arr = np.asarray(data, copy=False)
73+
if mask is not None:
74+
mask_arr = np.asarray(mask, copy=False).astype(np.bool_, copy=False)
75+
assert mask_arr.ndim == 3 # noqa: S101, PLR2004
76+
assert mask_arr.shape[2] == 1 # noqa: S101
77+
# This assumes it's always (height, width, 1)
78+
mask_arr = np.squeeze(mask_arr, axis=2)
79+
else:
80+
mask_arr = None
7181

7282
assert data_arr.ndim == 3, f"Expected 3D array, got {data_arr.ndim}D" # noqa: S101, PLR2004
7383

src/async_geotiff/_geotiff.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import cached_property
55
from typing import TYPE_CHECKING, Self
66

7+
import async_tiff.enums
78
from affine import Affine
89
from async_tiff import TIFF
910
from async_tiff.enums import PhotometricInterpretation
@@ -13,13 +14,14 @@
1314
from async_geotiff._ifd import IFDReference
1415
from async_geotiff._overview import Overview
1516
from async_geotiff._transform import TransformMixin
16-
from async_geotiff.enums import Compression, Interleaving
1717

1818
if TYPE_CHECKING:
1919
from async_tiff import GeoKeyDirectory, ImageFileDirectory, ObspecInput
2020
from async_tiff.store import ObjectStore # type: ignore # noqa: PGH003
2121
from pyproj import CRS
2222

23+
from async_geotiff.enums import Compression, Interleaving
24+
2325

2426
@dataclass(frozen=True, init=False, kw_only=True, repr=False)
2527
class GeoTIFF(FetchTileMixin, TransformMixin):
@@ -374,7 +376,7 @@ def has_geokeys(ifd: ImageFileDirectory) -> bool:
374376
def is_mask_ifd(ifd: ImageFileDirectory) -> bool:
375377
"""Check if an IFD is a mask IFD."""
376378
return (
377-
ifd.compression == Compression.deflate
379+
ifd.compression == async_tiff.enums.CompressionMethod.Deflate
378380
and ifd.new_subfile_type is not None
379381
and ifd.new_subfile_type & 4 != 0
380382
and ifd.photometric_interpretation == PhotometricInterpretation.TransparencyMask

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __call__(
3434
name: str,
3535
*,
3636
variant: Variant = "rasterio",
37+
OVERVIEW_LEVEL: int | None = None, # noqa: N803
3738
**kwargs: Any, # noqa: ANN401
3839
) -> Generator[DatasetReader, None, None]: ...
3940

@@ -80,6 +81,7 @@ def _load(
8081
name: str,
8182
*,
8283
variant: Variant = "rasterio",
84+
OVERVIEW_LEVEL: int | None = None, # noqa: N803
8385
**kwargs: Any, # noqa: ANN401
8486
) -> Generator[DatasetReader, None, None]:
8587
path = f"{root_dir}/fixtures/geotiff-test-data/"
@@ -93,6 +95,10 @@ def _load(
9395
raise ValueError(f"Unknown variant: {variant}")
9496

9597
path = f"{path}{name}.tif"
98+
99+
if OVERVIEW_LEVEL is not None and "OVERVIEW_LEVEL" not in kwargs:
100+
kwargs["OVERVIEW_LEVEL"] = OVERVIEW_LEVEL
101+
96102
with rasterio.open(path, **kwargs) as ds:
97103
yield ds
98104

tests/test_fetch.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,62 @@ async def test_fetch_overview(
7979

8080
np.testing.assert_array_equal(tile.data, rasterio_data)
8181
assert tile.crs == geotiff.crs
82+
83+
84+
@pytest.mark.asyncio
85+
@pytest.mark.parametrize(
86+
("file_name", "variant"),
87+
[
88+
("maxar_opendata_yellowstone_visual", "vantor"),
89+
],
90+
)
91+
async def test_mask(
92+
load_geotiff: LoadGeoTIFF,
93+
load_rasterio: LoadRasterio,
94+
file_name: str,
95+
variant: Variant,
96+
) -> None:
97+
geotiff = await load_geotiff(file_name, variant=variant)
98+
99+
tile = await geotiff.fetch_tile(0, 0)
100+
101+
assert tile.mask is not None
102+
assert isinstance(tile.mask, np.ndarray)
103+
assert tile.mask.dtype == np.bool_
104+
assert tile.mask.shape == tile.data.shape[1:]
105+
106+
window = Window(0, 0, geotiff.tile_width, geotiff.tile_height)
107+
with load_rasterio(file_name, variant=variant) as rasterio_ds:
108+
mask = rasterio_ds.dataset_mask(window=window)
109+
110+
np.testing.assert_array_equal(tile.mask, mask.astype(np.bool_))
111+
112+
113+
@pytest.mark.asyncio
114+
@pytest.mark.parametrize(
115+
("file_name", "variant"),
116+
[
117+
("maxar_opendata_yellowstone_visual", "vantor"),
118+
],
119+
)
120+
async def test_mask_overview(
121+
load_geotiff: LoadGeoTIFF,
122+
load_rasterio: LoadRasterio,
123+
file_name: str,
124+
variant: Variant,
125+
) -> None:
126+
geotiff = await load_geotiff(file_name, variant=variant)
127+
overview = geotiff.overviews[0]
128+
129+
tile = await overview.fetch_tile(0, 0)
130+
131+
assert tile.mask is not None
132+
assert isinstance(tile.mask, np.ndarray)
133+
assert tile.mask.dtype == np.bool_
134+
assert tile.mask.shape == tile.data.shape[1:]
135+
136+
window = Window(0, 0, overview.tile_width, overview.tile_height)
137+
with load_rasterio(file_name, variant=variant, OVERVIEW_LEVEL=0) as rasterio_ds:
138+
mask = rasterio_ds.dataset_mask(window=window)
139+
140+
np.testing.assert_array_equal(tile.mask, mask.astype(np.bool_))

0 commit comments

Comments
 (0)