diff --git a/src/async_geotiff/__init__.py b/src/async_geotiff/__init__.py index e8763d2..15c7878 100644 --- a/src/async_geotiff/__init__.py +++ b/src/async_geotiff/__init__.py @@ -3,8 +3,9 @@ [cogeo]: https://cogeo.org/ """ +from ._array import Array from ._geotiff import GeoTIFF from ._overview import Overview from ._version import __version__ -__all__ = ["GeoTIFF", "Overview", "__version__"] +__all__ = ["Array", "GeoTIFF", "Overview", "__version__"] diff --git a/src/async_geotiff/_array.py b/src/async_geotiff/_array.py new file mode 100644 index 0000000..8530118 --- /dev/null +++ b/src/async_geotiff/_array.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Self + +import numpy as np +from async_tiff.enums import PlanarConfiguration + +from async_geotiff._transform import TransformMixin + +if TYPE_CHECKING: + from affine import Affine + from async_tiff import Array as AsyncTiffArray + from numpy.typing import NDArray + from pyproj import CRS + + +@dataclass(frozen=True, kw_only=True, eq=False) +class Array(TransformMixin): + """An array representation of data from a GeoTIFF.""" + + data: NDArray + """The array data with shape (bands, height, width).""" + + mask: NDArray | None + """The mask array with shape (height, width), if any.""" + + width: int + """The width of the array in pixels.""" + + height: int + """The height of the array in pixels.""" + + count: int + """The number of bands in the array.""" + + transform: Affine + """The affine transform mapping pixel coordinates to geographic coordinates.""" + + crs: CRS + """The coordinate reference system of the array.""" + + @classmethod + def _create( + cls, + *, + data: AsyncTiffArray, + mask: AsyncTiffArray | None, + planar_configuration: PlanarConfiguration, + transform: Affine, + crs: CRS, + ) -> Self: + """Create an Array from async_tiff data. + + Handles axis reordering to ensure data is always in (bands, height, width) + order, matching rasterio's convention. + + Args: + data: The decoded tile data from async_tiff. + mask: The decoded mask data from async_tiff, if any. + planar_configuration: The planar configuration of the source IFD. + transform: The affine transform for this tile. + crs: The coordinate reference system. + + Returns: + An Array with data in (bands, height, width) order. + + """ + data_arr = np.asarray(data) + mask_arr = np.asarray(mask) if mask is not None else None + + assert data_arr.ndim == 3, f"Expected 3D array, got {data_arr.ndim}D" # noqa: S101, PLR2004 + + # async_tiff returns data in the native TIFF order: + # - Chunky (pixel interleaved): (height, width, bands) + # - Planar (band interleaved): (bands, height, width) + # We always want (bands, height, width) to match rasterio. + if planar_configuration == PlanarConfiguration.Chunky: + # Transpose from (H, W, C) to (C, H, W) + data_arr = np.moveaxis(data_arr, -1, 0) + + count, height, width = data_arr.shape + + return cls( + data=data_arr, + mask=mask_arr, + width=width, + height=height, + count=count, + transform=transform, + crs=crs, + ) diff --git a/src/async_geotiff/_fetch.py b/src/async_geotiff/_fetch.py new file mode 100644 index 0000000..cec557f --- /dev/null +++ b/src/async_geotiff/_fetch.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Protocol + +from affine import Affine + +from async_geotiff import Array +from async_geotiff._transform import HasTransform + +if TYPE_CHECKING: + from async_tiff import TIFF + from async_tiff import Array as AsyncTiffArray + from pyproj import CRS + + from async_geotiff._ifd import IFDReference + + +class HasTiffReference(HasTransform, Protocol): + """Protocol for objects that hold a TIFF reference and can request tiles.""" + + @property + def _ifd(self) -> IFDReference: + """The data IFD for this image (index, IFD).""" + ... + + @property + def _mask_ifd(self) -> IFDReference | None: + """The mask IFD for this image (index, IFD), if any.""" + ... + + @property + def _tiff(self) -> TIFF: + """A reference to the underlying TIFF object.""" + ... + + @property + def crs(self) -> CRS: + """The coordinate reference system.""" + ... + + @property + def tile_height(self) -> int: + """The height of tiles in pixels.""" + ... + + @property + def tile_width(self) -> int: + """The width of tiles in pixels.""" + ... + + +class FetchTileMixin: + """Mixin for fetching tiles from a GeoTIFF. + + Classes using this mixin must implement HasTiffReference. + """ + + async def fetch_tile( + self: HasTiffReference, + x: int, + y: int, + ) -> Array: + tile_fut = self._tiff.fetch_tile(x, y, self._ifd.index) + + mask_data: AsyncTiffArray | None = None + if self._mask_ifd is not None: + mask_ifd_index = self._mask_ifd.index + mask_fut = self._tiff.fetch_tile(x, y, mask_ifd_index) + tile, mask = await asyncio.gather(tile_fut, mask_fut) + tile_data, mask_data = await asyncio.gather(tile.decode(), mask.decode()) + else: + tile = await tile_fut + tile_data = await tile.decode() + + tile_transform = self.transform * Affine.translation( + x * self.tile_width, + y * self.tile_height, + ) + + return Array._create( # noqa: SLF001 + data=tile_data, + mask=mask_data, + planar_configuration=self._ifd.ifd.planar_configuration, + crs=self.crs, + transform=tile_transform, + ) + + async def fetch_tiles( + self: HasTiffReference, + xs: list[int], + ys: list[int], + ) -> list[Array]: + """Fetch multiple tiles from this overview. + + Args: + xs: The x coordinates of the tiles. + ys: The y coordinates of the tiles. + + """ + tiles_fut = self._tiff.fetch_tiles(xs, ys, self._ifd.index) + + decoded_masks: list[AsyncTiffArray | None] = [None] * len(xs) + if self._mask_ifd is not None: + mask_ifd_index = self._mask_ifd.index + masks_fut = self._tiff.fetch_tiles(xs, ys, mask_ifd_index) + tiles, masks = await asyncio.gather(tiles_fut, masks_fut) + + decoded_tile_futs = [tile.decode() for tile in tiles] + decoded_mask_futs = [mask.decode() for mask in masks] + decoded_tiles = await asyncio.gather(*decoded_tile_futs) + decoded_masks = await asyncio.gather(*decoded_mask_futs) + else: + tiles = await tiles_fut + decoded_tiles = await asyncio.gather(*[tile.decode() for tile in tiles]) + + arrays: list[Array] = [] + for x, y, tile_data, mask_data in zip( + xs, + ys, + decoded_tiles, + decoded_masks, + strict=True, + ): + tile_transform = self.transform * Affine.translation( + x * self.tile_width, + y * self.tile_height, + ) + array = Array._create( # noqa: SLF001 + data=tile_data, + mask=mask_data, + planar_configuration=self._ifd.ifd.planar_configuration, + crs=self.crs, + transform=tile_transform, + ) + arrays.append(array) + + return arrays diff --git a/src/async_geotiff/_geotiff.py b/src/async_geotiff/_geotiff.py index f6b995f..ab2084e 100644 --- a/src/async_geotiff/_geotiff.py +++ b/src/async_geotiff/_geotiff.py @@ -9,30 +9,45 @@ from async_tiff.enums import PhotometricInterpretation from async_geotiff._crs import crs_from_geo_keys +from async_geotiff._fetch import FetchTileMixin +from async_geotiff._ifd import IFDReference from async_geotiff._overview import Overview from async_geotiff._transform import TransformMixin from async_geotiff.enums import Compression, Interleaving if TYPE_CHECKING: - import pyproj from async_tiff import GeoKeyDirectory, ImageFileDirectory, ObspecInput from async_tiff.store import ObjectStore # type: ignore # noqa: PGH003 + from pyproj import CRS @dataclass(frozen=True, init=False, kw_only=True, repr=False) -class GeoTIFF(TransformMixin): +class GeoTIFF(FetchTileMixin, TransformMixin): """A class representing a GeoTIFF image.""" + _crs: CRS | None = None + """A cached CRS instance. + + We don't use functools.cached_property on the `crs` attribute because of typing + issues. + """ + _tiff: TIFF """The underlying async-tiff TIFF instance that we wrap. """ - _primary_ifd: ImageFileDirectory = field(init=False) + _primary_ifd: IFDReference = field(init=False) """The primary (first) IFD of the GeoTIFF. Some tags, like most geo tags, only exist on the primary IFD. """ + _mask_ifd: IFDReference | None = None + """The mask IFD of the full-resolution GeoTIFF, if any. + + (positional index of the IFD in the TIFF file, IFD object) + """ + _gkd: GeoKeyDirectory = field(init=False) """The GeoKeyDirectory of the primary IFD. """ @@ -41,6 +56,11 @@ class GeoTIFF(TransformMixin): """A list of overviews for the GeoTIFF. """ + @property + def _ifd(self) -> IFDReference: + """An alias for the primary IFD to satisfy _fetch protocol.""" + return self._primary_ifd + def __init__(self, tiff: TIFF) -> None: """Create a GeoTIFF from an existing TIFF instance.""" first_ifd = tiff.ifds[0] @@ -55,32 +75,35 @@ def __init__(self, tiff: TIFF) -> None: # We use object.__setattr__ because the dataclass is frozen object.__setattr__(self, "_tiff", tiff) - object.__setattr__(self, "_primary_ifd", first_ifd) + object.__setattr__(self, "_primary_ifd", IFDReference(index=0, ifd=first_ifd)) object.__setattr__(self, "_gkd", gkd) - # Skip the first IFD, since it's the primary image - ifd_idx = 1 + # Separate data IFDs and mask IFDs (skip the primary IFD at index 0) + # Data IFDs are indexed by (width, height) for matching with masks + data_ifds: dict[tuple[int, int], IFDReference] = {} + mask_ifds: dict[tuple[int, int], IFDReference] = {} + + for idx, ifd in enumerate(tiff.ifds[1:], start=1): + dims = (ifd.image_width, ifd.image_height) + if is_mask_ifd(ifd): + mask_ifds[dims] = IFDReference(index=idx, ifd=ifd) + else: + data_ifds[dims] = IFDReference(index=idx, ifd=ifd) + + # Find and set the mask for the primary IFD (matches primary dimensions) + if primary_mask_ifd := mask_ifds.get( + (first_ifd.image_width, first_ifd.image_height), + ): + object.__setattr__(self, "_mask_ifd", primary_mask_ifd) + + # Build overviews, sorted by resolution (highest to lowest, i.e., largest first) + # Sort by width * height descending + sorted_dims = sorted(data_ifds.keys(), key=lambda d: d[0] * d[1], reverse=True) + overviews: list[Overview] = [] - while True: - try: - data_ifd = (ifd_idx, tiff.ifds[ifd_idx]) - except IndexError: - # No more IFDs - break - - ifd_idx += 1 - - mask_ifd = None - next_ifd = None - try: - next_ifd = tiff.ifds[ifd_idx] - except IndexError: - # No more IFDs - pass - finally: - if next_ifd is not None and is_mask_ifd(next_ifd): - mask_ifd = (ifd_idx, next_ifd) - ifd_idx += 1 + for dims in sorted_dims: + data_ifd = data_ifds[dims] + mask_ifd = mask_ifds.get(dims) ovr = Overview._create( # noqa: SLF001 geotiff=self, @@ -203,10 +226,15 @@ def count(self) -> int: """The number of raster bands in the full image.""" raise NotImplementedError - @cached_property - def crs(self) -> pyproj.CRS: + @property + def crs(self) -> CRS: """The dataset's coordinate reference system.""" - return crs_from_geo_keys(self._gkd) + if self._crs is not None: + return self._crs + + crs = crs_from_geo_keys(self._gkd) + object.__setattr__(self, "_crs", crs) + return crs @property def dtypes(self) -> list[str]: @@ -219,14 +247,14 @@ def dtypes(self) -> list[str]: @property def height(self) -> int: """The height (number of rows) of the full image.""" - return self._primary_ifd.image_height + return self._primary_ifd.ifd.image_height def indexes(self) -> list[int]: """Return the 1-based indexes of each band in the dataset. For a 3-band dataset, this property will be [1, 2, 3]. """ - return list(range(1, self._primary_ifd.samples_per_pixel + 1)) + return list(range(1, self._primary_ifd.ifd.samples_per_pixel + 1)) @property def interleaving(self) -> Interleaving: @@ -243,7 +271,7 @@ def is_tiled(self) -> bool: @property def nodata(self) -> float | None: """The dataset's single nodata value.""" - nodata = self._primary_ifd.gdal_nodata + nodata = self._primary_ifd.ifd.gdal_nodata if nodata is None: return None @@ -272,15 +300,25 @@ def shape(self) -> tuple[int, int]: """Get the shape (height, width) of the full image.""" return (self.height, self.width) - @cached_property + @property + def tile_height(self) -> int: + """The height in pixels per tile of the image.""" + return self._primary_ifd.ifd.tile_height or self.height + + @property + def tile_width(self) -> int: + """The width in pixels per tile of the image.""" + return self._primary_ifd.ifd.tile_width or self.width + + @property def transform(self) -> Affine: """Return the dataset's georeferencing transformation matrix. This transform maps pixel row/column coordinates to coordinates in the dataset's CRS. """ - if (tie_points := self._primary_ifd.model_tiepoint) and ( - model_scale := self._primary_ifd.model_pixel_scale + if (tie_points := self._primary_ifd.ifd.model_tiepoint) and ( + model_scale := self._primary_ifd.ifd.model_pixel_scale ): x_origin = tie_points[3] y_origin = tie_points[4] @@ -289,7 +327,7 @@ def transform(self) -> Affine: return Affine(x_resolution, 0, x_origin, 0, y_resolution, y_origin) - if model_transformation := self._primary_ifd.model_transformation: + if model_transformation := self._primary_ifd.ifd.model_transformation: # ModelTransformation is a 4x4 matrix in row-major order # [0 1 2 3 ] [a b 0 c] # [4 5 6 7 ] = [d e 0 f] @@ -320,7 +358,7 @@ def transform(self) -> Affine: @property def width(self) -> int: """The width (number of columns) of the full image.""" - return self._primary_ifd.image_width + return self._primary_ifd.ifd.image_width def has_geokeys(ifd: ImageFileDirectory) -> bool: diff --git a/src/async_geotiff/_ifd.py b/src/async_geotiff/_ifd.py new file mode 100644 index 0000000..a351eb4 --- /dev/null +++ b/src/async_geotiff/_ifd.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from async_tiff import ImageFileDirectory + + +@dataclass(frozen=True, kw_only=True, repr=False) +class IFDReference: + """A reference to an Image File Directory (IFD) in a TIFF file.""" + + index: int + """The positional index of the IFD in the TIFF file.""" + + ifd: ImageFileDirectory + """The IFD object itself.""" diff --git a/src/async_geotiff/_overview.py b/src/async_geotiff/_overview.py index feab4d7..c1407be 100644 --- a/src/async_geotiff/_overview.py +++ b/src/async_geotiff/_overview.py @@ -1,21 +1,25 @@ from __future__ import annotations from dataclasses import dataclass -from functools import cached_property from typing import TYPE_CHECKING from affine import Affine +from async_geotiff._fetch import FetchTileMixin from async_geotiff._transform import TransformMixin if TYPE_CHECKING: - from async_tiff import GeoKeyDirectory, ImageFileDirectory + from async_tiff import TIFF, GeoKeyDirectory + from pyproj import CRS from async_geotiff import GeoTIFF + from async_geotiff._ifd import IFDReference + +# ruff: noqa: SLF001 @dataclass(init=False, frozen=True, kw_only=True, eq=False, repr=False) -class Overview(TransformMixin): +class Overview(FetchTileMixin, TransformMixin): """An overview level of a Cloud-Optimized GeoTIFF image.""" _geotiff: GeoTIFF @@ -26,13 +30,13 @@ class Overview(TransformMixin): """The GeoKeyDirectory of the primary IFD. """ - _ifd: tuple[int, ImageFileDirectory] + _ifd: IFDReference """The IFD for this overview level. (positional index of the IFD in the TIFF file, IFD object) """ - _mask_ifd: tuple[int, ImageFileDirectory] | None + _mask_ifd: IFDReference | None """The IFD for the mask associated with this overview level, if any. (positional index of the IFD in the TIFF file, IFD object) @@ -44,8 +48,8 @@ def _create( *, geotiff: GeoTIFF, gkd: GeoKeyDirectory, - ifd: tuple[int, ImageFileDirectory], - mask_ifd: tuple[int, ImageFileDirectory] | None, + ifd: IFDReference, + mask_ifd: IFDReference | None, ) -> Overview: instance = cls.__new__(cls) @@ -57,13 +61,33 @@ def _create( return instance + @property + def _tiff(self) -> TIFF: + """A reference to the underlying TIFF object.""" + return self._geotiff._tiff + + @property + def crs(self) -> CRS: + """The coordinate reference system of the overview.""" + return self._geotiff.crs + @property def height(self) -> int: """The height of the overview in pixels.""" - return self._ifd[1].image_height + return self._ifd.ifd.image_height - @cached_property - def transform(self) -> Affine: + @property + def tile_height(self) -> int: + """The height in pixels per tile of the overview.""" + return self._ifd.ifd.tile_height or self.height + + @property + def tile_width(self) -> int: + """The width in pixels per tile of the overview.""" + return self._ifd.ifd.tile_width or self.width + + @property + def transform(self) -> Affine: # type: ignore[override] """The affine transform mapping pixel coordinates to geographic coordinates. Returns: @@ -72,9 +96,9 @@ def transform(self) -> Affine: """ full_transform = self._geotiff.transform - overview_width = self._ifd[1].image_width + overview_width = self._ifd.ifd.image_width full_width = self._geotiff.width - overview_height = self._ifd[1].image_height + overview_height = self._ifd.ifd.image_height full_height = self._geotiff.height scale_x = full_width / overview_width @@ -85,4 +109,4 @@ def transform(self) -> Affine: @property def width(self) -> int: """The width of the overview in pixels.""" - return self._ifd[1].image_width + return self._ifd.ifd.image_width diff --git a/src/async_geotiff/_transform.py b/src/async_geotiff/_transform.py index 08f928d..e28b920 100644 --- a/src/async_geotiff/_transform.py +++ b/src/async_geotiff/_transform.py @@ -21,8 +21,7 @@ def transform(self) -> Affine: ... class TransformMixin: """Mixin providing coordinate transformation methods. - Classes using this mixin must have a `transform` property that returns - an `Affine` transformation matrix. + Classes using this mixin must implement HasTransform. """ def index( diff --git a/src/async_geotiff/tms.py b/src/async_geotiff/tms.py index 35d287b..1ebd707 100644 --- a/src/async_geotiff/tms.py +++ b/src/async_geotiff/tms.py @@ -47,8 +47,8 @@ def generate_tms( bounds = geotiff.bounds crs = geotiff.crs tr = geotiff.transform - blockxsize = geotiff._primary_ifd.tile_width # noqa: SLF001 - blockysize = geotiff._primary_ifd.tile_height # noqa: SLF001 + blockxsize = geotiff._primary_ifd.ifd.tile_width # noqa: SLF001 + blockysize = geotiff._primary_ifd.ifd.tile_height # noqa: SLF001 if blockxsize is None or blockysize is None: raise ValueError("GeoTIFF must be tiled to generate a TMS.") @@ -63,8 +63,8 @@ def generate_tms( for idx, overview in enumerate(reversed(geotiff.overviews)): overview_tr = overview.transform - blockxsize = overview._ifd[1].tile_width # noqa: SLF001 - blockysize = overview._ifd[1].tile_height # noqa: SLF001 + blockxsize = overview._ifd.ifd.tile_width # noqa: SLF001 + blockysize = overview._ifd.ifd.tile_height # noqa: SLF001 if blockxsize is None or blockysize is None: raise ValueError("GeoTIFF overviews must be tiled to generate a TMS.") diff --git a/tests/conftest.py b/tests/conftest.py index ef2771d..8bdfeba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Literal, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol import pytest import rasterio @@ -34,6 +34,7 @@ def __call__( name: str, *, variant: Variant = "rasterio", + **kwargs: Any, # noqa: ANN401 ) -> Generator[DatasetReader, None, None]: ... @@ -79,6 +80,7 @@ def _load( name: str, *, variant: Variant = "rasterio", + **kwargs: Any, # noqa: ANN401 ) -> Generator[DatasetReader, None, None]: path = f"{root_dir}/fixtures/geotiff-test-data/" if variant == "rasterio": @@ -91,7 +93,7 @@ def _load( raise ValueError(f"Unknown variant: {variant}") path = f"{path}{name}.tif" - with rasterio.open(path) as ds: + with rasterio.open(path, **kwargs) as ds: yield ds return _load diff --git a/tests/test_fetch.py b/tests/test_fetch.py new file mode 100644 index 0000000..221d2fe --- /dev/null +++ b/tests/test_fetch.py @@ -0,0 +1,81 @@ +"""Test fetching tiles from a GeoTIFF.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from rasterio.windows import Window + +if TYPE_CHECKING: + from .conftest import LoadGeoTIFF, LoadRasterio, Variant + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("file_name", "variant"), + [ + # TODO: support LERC + # https://github.com/developmentseed/async-geotiff/issues/34 + # ("float32_1band_lerc_block32", "rasterio"), # noqa: ERA001 + ("uint16_1band_lzw_block128_predictor2", "rasterio"), + ("uint8_rgb_deflate_block64_cog", "rasterio"), + ("uint8_1band_deflate_block128_unaligned", "rasterio"), + # TODO: debug incorrect data length + # https://github.com/developmentseed/async-tiff/issues/202 + # ("maxar_opendata_yellowstone_visual", "vantor"), # noqa: ERA001 + ("nlcd_landcover", "nlcd"), + ], +) +async def test_fetch( + load_geotiff: LoadGeoTIFF, + load_rasterio: LoadRasterio, + file_name: str, + variant: Variant, +) -> None: + geotiff = await load_geotiff(file_name, variant=variant) + + tile = await geotiff.fetch_tile(0, 0) + + window = Window(0, 0, geotiff.tile_width, geotiff.tile_height) + with load_rasterio(file_name, variant=variant) as rasterio_ds: + rasterio_data = rasterio_ds.read(window=window) + + np.testing.assert_array_equal(tile.data, rasterio_data) + assert tile.crs == geotiff.crs + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("file_name", "variant"), + [ + # TODO: support LERC + # https://github.com/developmentseed/async-geotiff/issues/34 + # ("float32_1band_lerc_block32", "rasterio"), # noqa: ERA001 + ("uint16_1band_lzw_block128_predictor2", "rasterio"), + ("uint8_rgb_deflate_block64_cog", "rasterio"), + ("uint8_1band_deflate_block128_unaligned", "rasterio"), + # TODO: debug incorrect data length + # https://github.com/developmentseed/async-tiff/issues/202 + # ("maxar_opendata_yellowstone_visual", "vantor"), # noqa: ERA001 + ("nlcd_landcover", "nlcd"), + ], +) +async def test_fetch_overview( + load_geotiff: LoadGeoTIFF, + load_rasterio: LoadRasterio, + file_name: str, + variant: Variant, +) -> None: + geotiff = await load_geotiff(file_name, variant=variant) + overview = geotiff.overviews[0] + + tile = await overview.fetch_tile(0, 0) + + window = Window(0, 0, overview.tile_width, overview.tile_height) + with load_rasterio(file_name, variant=variant, OVERVIEW_LEVEL=0) as rasterio_ds: + rasterio_data = rasterio_ds.read(window=window) + + np.testing.assert_array_equal(tile.data, rasterio_data) + assert tile.crs == geotiff.crs diff --git a/uv.lock b/uv.lock index a77e068..b874e38 100644 --- a/uv.lock +++ b/uv.lock @@ -47,7 +47,7 @@ wheels = [ [[package]] name = "async-geotiff" -version = "0.1.0b2" +version = "0.1.0b3" source = { editable = "." } dependencies = [ { name = "affine" },