Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/async_geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
[cogeo]: https://cogeo.org/
"""

from ._array import Array
from ._geotiff import GeoTIFF
from ._overview import Overview
from ._version import __version__

__all__ = ["GeoTIFF", "Overview", "__version__"]
__all__ = ["Array", "GeoTIFF", "Overview", "__version__"]
92 changes: 92 additions & 0 deletions src/async_geotiff/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Self

import numpy as np
from async_tiff.enums import PlanarConfiguration

from async_geotiff._transform import TransformMixin

if TYPE_CHECKING:
from affine import Affine
from async_tiff import Array as AsyncTiffArray
from numpy.typing import NDArray
from pyproj import CRS


@dataclass(frozen=True, kw_only=True, eq=False)
class Array(TransformMixin):
"""An array representation of data from a GeoTIFF."""

data: NDArray
"""The array data with shape (bands, height, width)."""

mask: NDArray | None
"""The mask array with shape (height, width), if any."""

width: int
"""The width of the array in pixels."""

height: int
"""The height of the array in pixels."""

count: int
"""The number of bands in the array."""

transform: Affine
"""The affine transform mapping pixel coordinates to geographic coordinates."""

crs: CRS
"""The coordinate reference system of the array."""

@classmethod
def _create(
cls,
*,
data: AsyncTiffArray,
mask: AsyncTiffArray | None,
planar_configuration: PlanarConfiguration,
transform: Affine,
crs: CRS,
) -> Self:
"""Create an Array from async_tiff data.

Handles axis reordering to ensure data is always in (bands, height, width)
order, matching rasterio's convention.

Args:
data: The decoded tile data from async_tiff.
mask: The decoded mask data from async_tiff, if any.
planar_configuration: The planar configuration of the source IFD.
transform: The affine transform for this tile.
crs: The coordinate reference system.

Returns:
An Array with data in (bands, height, width) order.

"""
data_arr = np.asarray(data)
mask_arr = np.asarray(mask) if mask is not None else None

assert data_arr.ndim == 3, f"Expected 3D array, got {data_arr.ndim}D" # noqa: S101, PLR2004

# async_tiff returns data in the native TIFF order:
# - Chunky (pixel interleaved): (height, width, bands)
# - Planar (band interleaved): (bands, height, width)
# We always want (bands, height, width) to match rasterio.
if planar_configuration == PlanarConfiguration.Chunky:
# Transpose from (H, W, C) to (C, H, W)
data_arr = np.moveaxis(data_arr, -1, 0)

count, height, width = data_arr.shape

return cls(
data=data_arr,
mask=mask_arr,
width=width,
height=height,
count=count,
transform=transform,
crs=crs,
)
138 changes: 138 additions & 0 deletions src/async_geotiff/_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Protocol

from affine import Affine

from async_geotiff import Array
from async_geotiff._transform import HasTransform

if TYPE_CHECKING:
from async_tiff import TIFF
from async_tiff import Array as AsyncTiffArray
from pyproj import CRS

from async_geotiff._ifd import IFDReference


class HasTiffReference(HasTransform, Protocol):
"""Protocol for objects that hold a TIFF reference and can request tiles."""

@property
def _ifd(self) -> IFDReference:
"""The data IFD for this image (index, IFD)."""
...

@property
def _mask_ifd(self) -> IFDReference | None:
"""The mask IFD for this image (index, IFD), if any."""
...

@property
def _tiff(self) -> TIFF:
"""A reference to the underlying TIFF object."""
...

@property
def crs(self) -> CRS:
"""The coordinate reference system."""
...

@property
def tile_height(self) -> int:
"""The height of tiles in pixels."""
...

@property
def tile_width(self) -> int:
"""The width of tiles in pixels."""
...


class FetchTileMixin:
"""Mixin for fetching tiles from a GeoTIFF.

Classes using this mixin must implement HasTiffReference.
"""

async def fetch_tile(
self: HasTiffReference,
x: int,
y: int,
) -> Array:
tile_fut = self._tiff.fetch_tile(x, y, self._ifd.index)

mask_data: AsyncTiffArray | None = None
if self._mask_ifd is not None:
mask_ifd_index = self._mask_ifd.index
mask_fut = self._tiff.fetch_tile(x, y, mask_ifd_index)
tile, mask = await asyncio.gather(tile_fut, mask_fut)
tile_data, mask_data = await asyncio.gather(tile.decode(), mask.decode())
else:
tile = await tile_fut
tile_data = await tile.decode()

tile_transform = self.transform * Affine.translation(
x * self.tile_width,
y * self.tile_height,
)

return Array._create( # noqa: SLF001
data=tile_data,
mask=mask_data,
planar_configuration=self._ifd.ifd.planar_configuration,
crs=self.crs,
transform=tile_transform,
)

async def fetch_tiles(
self: HasTiffReference,
xs: list[int],
ys: list[int],
) -> list[Array]:
"""Fetch multiple tiles from this overview.

Args:
xs: The x coordinates of the tiles.
ys: The y coordinates of the tiles.

"""
tiles_fut = self._tiff.fetch_tiles(xs, ys, self._ifd.index)

decoded_masks: list[AsyncTiffArray | None] = [None] * len(xs)
if self._mask_ifd is not None:
mask_ifd_index = self._mask_ifd.index
masks_fut = self._tiff.fetch_tiles(xs, ys, mask_ifd_index)
tiles, masks = await asyncio.gather(tiles_fut, masks_fut)

decoded_tile_futs = [tile.decode() for tile in tiles]
decoded_mask_futs = [mask.decode() for mask in masks]
decoded_tiles = await asyncio.gather(*decoded_tile_futs)
decoded_masks = await asyncio.gather(*decoded_mask_futs)
else:
tiles = await tiles_fut
decoded_tiles = await asyncio.gather(*[tile.decode() for tile in tiles])

arrays: list[Array] = []
for x, y, tile_data, mask_data in zip(
xs,
ys,
decoded_tiles,
decoded_masks,
strict=True,
):
tile_transform = self.transform * Affine.translation(
x * self.tile_width,
y * self.tile_height,
)
array = Array._create( # noqa: SLF001
data=tile_data,
mask=mask_data,
planar_configuration=self._ifd.ifd.planar_configuration,
crs=self.crs,
transform=tile_transform,
)
arrays.append(array)

return arrays
Loading
Loading