Skip to content

Commit 98ae83b

Browse files
authored
feat: Fetch tile data from GeoTIFF/Overview (#22)
* feat: Fetch tile data from GeoTIFF overview * Ensure we set top-level mask IFD correctly * Refactor fetch tiles * Add fetch_tile to top-level `GeoTIFF` * Add `fetch_tiles` to full-resolution GeoTIFF * move into type checking block * Add test for tile fetch * Add TransformMixin parent * Use mixin for defining tile fetching * Fix typing * cleaner association of masks to data ifds * Add IFDReference dataclass as abstraction to hold both the ifd and its index * Fix axis ordering * add test fetching overview * Update tests
1 parent 1013db7 commit 98ae83b

File tree

11 files changed

+453
-60
lines changed

11 files changed

+453
-60
lines changed

src/async_geotiff/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
[cogeo]: https://cogeo.org/
44
"""
55

6+
from ._array import Array
67
from ._geotiff import GeoTIFF
78
from ._overview import Overview
89
from ._version import __version__
910

10-
__all__ = ["GeoTIFF", "Overview", "__version__"]
11+
__all__ = ["Array", "GeoTIFF", "Overview", "__version__"]

src/async_geotiff/_array.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Self
5+
6+
import numpy as np
7+
from async_tiff.enums import PlanarConfiguration
8+
9+
from async_geotiff._transform import TransformMixin
10+
11+
if TYPE_CHECKING:
12+
from affine import Affine
13+
from async_tiff import Array as AsyncTiffArray
14+
from numpy.typing import NDArray
15+
from pyproj import CRS
16+
17+
18+
@dataclass(frozen=True, kw_only=True, eq=False)
19+
class Array(TransformMixin):
20+
"""An array representation of data from a GeoTIFF."""
21+
22+
data: NDArray
23+
"""The array data with shape (bands, height, width)."""
24+
25+
mask: NDArray | None
26+
"""The mask array with shape (height, width), if any."""
27+
28+
width: int
29+
"""The width of the array in pixels."""
30+
31+
height: int
32+
"""The height of the array in pixels."""
33+
34+
count: int
35+
"""The number of bands in the array."""
36+
37+
transform: Affine
38+
"""The affine transform mapping pixel coordinates to geographic coordinates."""
39+
40+
crs: CRS
41+
"""The coordinate reference system of the array."""
42+
43+
@classmethod
44+
def _create(
45+
cls,
46+
*,
47+
data: AsyncTiffArray,
48+
mask: AsyncTiffArray | None,
49+
planar_configuration: PlanarConfiguration,
50+
transform: Affine,
51+
crs: CRS,
52+
) -> Self:
53+
"""Create an Array from async_tiff data.
54+
55+
Handles axis reordering to ensure data is always in (bands, height, width)
56+
order, matching rasterio's convention.
57+
58+
Args:
59+
data: The decoded tile data from async_tiff.
60+
mask: The decoded mask data from async_tiff, if any.
61+
planar_configuration: The planar configuration of the source IFD.
62+
transform: The affine transform for this tile.
63+
crs: The coordinate reference system.
64+
65+
Returns:
66+
An Array with data in (bands, height, width) order.
67+
68+
"""
69+
data_arr = np.asarray(data)
70+
mask_arr = np.asarray(mask) if mask is not None else None
71+
72+
assert data_arr.ndim == 3, f"Expected 3D array, got {data_arr.ndim}D" # noqa: S101, PLR2004
73+
74+
# async_tiff returns data in the native TIFF order:
75+
# - Chunky (pixel interleaved): (height, width, bands)
76+
# - Planar (band interleaved): (bands, height, width)
77+
# We always want (bands, height, width) to match rasterio.
78+
if planar_configuration == PlanarConfiguration.Chunky:
79+
# Transpose from (H, W, C) to (C, H, W)
80+
data_arr = np.moveaxis(data_arr, -1, 0)
81+
82+
count, height, width = data_arr.shape
83+
84+
return cls(
85+
data=data_arr,
86+
mask=mask_arr,
87+
width=width,
88+
height=height,
89+
count=count,
90+
transform=transform,
91+
crs=crs,
92+
)

src/async_geotiff/_fetch.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import TYPE_CHECKING, Protocol
5+
6+
from affine import Affine
7+
8+
from async_geotiff import Array
9+
from async_geotiff._transform import HasTransform
10+
11+
if TYPE_CHECKING:
12+
from async_tiff import TIFF
13+
from async_tiff import Array as AsyncTiffArray
14+
from pyproj import CRS
15+
16+
from async_geotiff._ifd import IFDReference
17+
18+
19+
class HasTiffReference(HasTransform, Protocol):
20+
"""Protocol for objects that hold a TIFF reference and can request tiles."""
21+
22+
@property
23+
def _ifd(self) -> IFDReference:
24+
"""The data IFD for this image (index, IFD)."""
25+
...
26+
27+
@property
28+
def _mask_ifd(self) -> IFDReference | None:
29+
"""The mask IFD for this image (index, IFD), if any."""
30+
...
31+
32+
@property
33+
def _tiff(self) -> TIFF:
34+
"""A reference to the underlying TIFF object."""
35+
...
36+
37+
@property
38+
def crs(self) -> CRS:
39+
"""The coordinate reference system."""
40+
...
41+
42+
@property
43+
def tile_height(self) -> int:
44+
"""The height of tiles in pixels."""
45+
...
46+
47+
@property
48+
def tile_width(self) -> int:
49+
"""The width of tiles in pixels."""
50+
...
51+
52+
53+
class FetchTileMixin:
54+
"""Mixin for fetching tiles from a GeoTIFF.
55+
56+
Classes using this mixin must implement HasTiffReference.
57+
"""
58+
59+
async def fetch_tile(
60+
self: HasTiffReference,
61+
x: int,
62+
y: int,
63+
) -> Array:
64+
tile_fut = self._tiff.fetch_tile(x, y, self._ifd.index)
65+
66+
mask_data: AsyncTiffArray | None = None
67+
if self._mask_ifd is not None:
68+
mask_ifd_index = self._mask_ifd.index
69+
mask_fut = self._tiff.fetch_tile(x, y, mask_ifd_index)
70+
tile, mask = await asyncio.gather(tile_fut, mask_fut)
71+
tile_data, mask_data = await asyncio.gather(tile.decode(), mask.decode())
72+
else:
73+
tile = await tile_fut
74+
tile_data = await tile.decode()
75+
76+
tile_transform = self.transform * Affine.translation(
77+
x * self.tile_width,
78+
y * self.tile_height,
79+
)
80+
81+
return Array._create( # noqa: SLF001
82+
data=tile_data,
83+
mask=mask_data,
84+
planar_configuration=self._ifd.ifd.planar_configuration,
85+
crs=self.crs,
86+
transform=tile_transform,
87+
)
88+
89+
async def fetch_tiles(
90+
self: HasTiffReference,
91+
xs: list[int],
92+
ys: list[int],
93+
) -> list[Array]:
94+
"""Fetch multiple tiles from this overview.
95+
96+
Args:
97+
xs: The x coordinates of the tiles.
98+
ys: The y coordinates of the tiles.
99+
100+
"""
101+
tiles_fut = self._tiff.fetch_tiles(xs, ys, self._ifd.index)
102+
103+
decoded_masks: list[AsyncTiffArray | None] = [None] * len(xs)
104+
if self._mask_ifd is not None:
105+
mask_ifd_index = self._mask_ifd.index
106+
masks_fut = self._tiff.fetch_tiles(xs, ys, mask_ifd_index)
107+
tiles, masks = await asyncio.gather(tiles_fut, masks_fut)
108+
109+
decoded_tile_futs = [tile.decode() for tile in tiles]
110+
decoded_mask_futs = [mask.decode() for mask in masks]
111+
decoded_tiles = await asyncio.gather(*decoded_tile_futs)
112+
decoded_masks = await asyncio.gather(*decoded_mask_futs)
113+
else:
114+
tiles = await tiles_fut
115+
decoded_tiles = await asyncio.gather(*[tile.decode() for tile in tiles])
116+
117+
arrays: list[Array] = []
118+
for x, y, tile_data, mask_data in zip(
119+
xs,
120+
ys,
121+
decoded_tiles,
122+
decoded_masks,
123+
strict=True,
124+
):
125+
tile_transform = self.transform * Affine.translation(
126+
x * self.tile_width,
127+
y * self.tile_height,
128+
)
129+
array = Array._create( # noqa: SLF001
130+
data=tile_data,
131+
mask=mask_data,
132+
planar_configuration=self._ifd.ifd.planar_configuration,
133+
crs=self.crs,
134+
transform=tile_transform,
135+
)
136+
arrays.append(array)
137+
138+
return arrays

0 commit comments

Comments
 (0)