Skip to content

Commit 261b36d

Browse files
authored
Merge pull request #15 from Dinghye/6-tessera-crs-problem
6 tessera crs problem
2 parents 6bdd124 + 39aca36 commit 261b36d

File tree

2 files changed

+225
-9
lines changed

2 files changed

+225
-9
lines changed

src/rs_embed/embedders/precomputed_tessera.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,56 @@ def _tile_bounds(transform, w: int, h: int) -> Tuple[float, float, float, float]
8787
return left, bottom, right, top
8888

8989

90+
def _reproject_tile(
91+
hwc: np.ndarray,
92+
src_transform: Any,
93+
src_crs: str,
94+
dst_crs: str,
95+
target_res: Optional[Tuple[float, float]] = None,
96+
) -> Tuple[np.ndarray, Any]:
97+
"""Reproject an HWC embedding tile to *dst_crs* via nearest-neighbour.
98+
99+
If *target_res* is given as ``(pixel_width, pixel_height)`` (both
100+
positive), the output is snapped to that resolution so tiles can be
101+
mosaicked without sub-pixel drift.
102+
"""
103+
try:
104+
from rasterio.warp import reproject, Resampling, calculate_default_transform
105+
from rasterio.transform import array_bounds
106+
except ImportError as exc:
107+
raise ModelError(
108+
"Mixed-CRS mosaic requires rasterio. Install: pip install rasterio"
109+
) from exc
110+
111+
h, w, d = hwc.shape
112+
src_bounds = array_bounds(h, w, src_transform)
113+
114+
kwargs: Dict[str, Any] = {}
115+
if target_res is not None:
116+
kwargs["resolution"] = (abs(target_res[0]), abs(target_res[1]))
117+
118+
dst_transform, dst_w, dst_h = calculate_default_transform(
119+
src_crs, dst_crs, w, h, *src_bounds, **kwargs,
120+
)
121+
122+
dst_hwc = np.zeros((dst_h, dst_w, d), dtype=np.float32)
123+
for i in range(d):
124+
src_band = np.ascontiguousarray(hwc[:, :, i])
125+
dst_band = np.zeros((dst_h, dst_w), dtype=np.float32)
126+
reproject(
127+
source=src_band,
128+
destination=dst_band,
129+
src_transform=src_transform,
130+
src_crs=src_crs,
131+
dst_transform=dst_transform,
132+
dst_crs=dst_crs,
133+
resampling=Resampling.nearest,
134+
)
135+
dst_hwc[:, :, i] = dst_band
136+
137+
return dst_hwc, dst_transform
138+
139+
90140
def _reproject_bbox_4326_to(tile_crs_str: str, bbox: BBox) -> Tuple[float, float, float, float]:
91141
# returns (xmin, ymin, xmax, ymax) in tile CRS
92142
if str(tile_crs_str).upper() in ("EPSG:4326", "WGS84", "CRS:84"):
@@ -111,34 +161,55 @@ def _mosaic_and_crop_strict_roi(
111161
tiles_rows: list of (year, tile_lon, tile_lat, embedding_array, crs, transform)
112162
Return cropped CHW + meta.
113163
"""
164+
if not tiles_rows:
165+
raise ModelError("No tiles fetched; cannot mosaic.")
166+
167+
# --- detect mixed CRS and choose the most-common one as target ---
168+
crs_counts: Dict[str, int] = {}
169+
for _, _, _, _, crs, _ in tiles_rows:
170+
key = str(crs)
171+
crs_counts[key] = crs_counts.get(key, 0) + 1
172+
target_crs = max(crs_counts, key=lambda k: crs_counts[k])
173+
mixed_crs = len(crs_counts) > 1
174+
175+
# when CRS differ, lock the target pixel size from a native tile
176+
target_res: Optional[Tuple[float, float]] = None
177+
if mixed_crs:
178+
for _, _, _, _, crs, transform in tiles_rows:
179+
if str(crs) == target_crs:
180+
target_res = (abs(float(transform.a)), abs(float(transform.e)))
181+
break
182+
114183
# normalize + collect tile meta
115184
hwc_list = []
116-
crs0 = None
185+
crs0 = target_crs
117186
a0 = e0 = None
118187

119188
bounds_list = []
120189
for year, tlon, tlat, emb, crs, transform in tiles_rows:
121190
_assert_north_up(transform)
122191
hwc = _to_hwc(emb)
192+
193+
# reproject tile to target CRS when CRS differ
194+
if str(crs) != target_crs:
195+
hwc, transform = _reproject_tile(
196+
hwc, transform, str(crs), target_crs, target_res=target_res,
197+
)
198+
123199
h, w, d = hwc.shape
124200
left, bottom, right, top = _tile_bounds(transform, w, h)
125201

126-
if crs0 is None:
127-
crs0 = crs
202+
if a0 is None:
128203
a0 = float(transform.a)
129204
e0 = float(transform.e)
130205
else:
131-
if str(crs) != str(crs0):
132-
raise ModelError("Tiles have different CRS; cannot mosaic.")
133-
if abs(float(transform.a) - a0) > 1e-12 or abs(float(transform.e) - e0) > 1e-12:
206+
# reprojected tiles may have sub-pixel rounding; use 1e-6 tolerance
207+
if abs(float(transform.a) - a0) > 1e-6 or abs(float(transform.e) - e0) > 1e-6:
134208
raise ModelError("Tiles have different resolution; cannot mosaic without resampling.")
135209

136210
hwc_list.append((hwc, transform, (left, bottom, right, top)))
137211
bounds_list.append((left, bottom, right, top))
138212

139-
if crs0 is None:
140-
raise ModelError("No tiles fetched; cannot mosaic.")
141-
142213
# global mosaic bounds
143214
left = min(b[0] for b in bounds_list)
144215
bottom = min(b[1] for b in bounds_list)

tests/test_tessera_crs.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Tests for mixed-CRS mosaic in precomputed_tessera (GitHub issue #6).
2+
3+
When an ROI sits on a UTM zone boundary (e.g. lon = 120°) the tile
4+
store can return tiles in *two* EPSG codes (e.g. 32650 / 32651). The
5+
mosaic helper must reproject them to a common CRS instead of rejecting
6+
the operation.
7+
"""
8+
from __future__ import annotations
9+
10+
import numpy as np
11+
import pytest
12+
13+
from affine import Affine
14+
from pyproj import Transformer
15+
16+
from rs_embed.core.specs import BBox
17+
from rs_embed.core.errors import ModelError
18+
from rs_embed.embedders.precomputed_tessera import (
19+
_mosaic_and_crop_strict_roi,
20+
_reproject_tile,
21+
)
22+
23+
rasterio = pytest.importorskip("rasterio", reason="rasterio needed for mixed-CRS tests")
24+
25+
# ---------------------------------------------------------------------------
26+
# Helpers
27+
# ---------------------------------------------------------------------------
28+
29+
D = 64 # embedding dimension (must be in _to_hwc's allowlist)
30+
TILE_HW = 8 # small tiles for fast tests
31+
PX_M = 500.0 # 500 m pixels
32+
33+
34+
def _make_tile(lon: float, lat: float, epsg: str, value: float = 1.0):
35+
"""Create a synthetic (year, lon, lat, embedding, crs, transform) tuple."""
36+
tfm = Transformer.from_crs("EPSG:4326", epsg, always_xy=True)
37+
x, y = tfm.transform(lon, lat)
38+
39+
# north-up transform: origin at top-left of tile
40+
transform = Affine(PX_M, 0.0, x, 0.0, -PX_M, y + TILE_HW * PX_M)
41+
emb = np.full((TILE_HW, TILE_HW, D), value, dtype=np.float32)
42+
return (2021, lon, lat, emb, epsg, transform)
43+
44+
45+
# ---------------------------------------------------------------------------
46+
# Tests
47+
# ---------------------------------------------------------------------------
48+
49+
50+
class TestMixedCRSMosaic:
51+
"""Regression tests for issue #6 – UTM zone boundary."""
52+
53+
def test_same_crs_still_works(self):
54+
"""Sanity: single-CRS mosaic must still succeed."""
55+
tile_a = _make_tile(119.5, 30.0, "EPSG:32650", value=1.0)
56+
tile_b = _make_tile(119.6, 30.0, "EPSG:32650", value=2.0)
57+
58+
bbox = BBox(minlon=119.45, minlat=29.95, maxlon=119.65, maxlat=30.05)
59+
chw, meta = _mosaic_and_crop_strict_roi([tile_a, tile_b], bbox_4326=bbox)
60+
61+
assert chw.ndim == 3
62+
assert chw.shape[0] == D
63+
assert meta["tile_crs"] == "EPSG:32650"
64+
65+
def test_mixed_crs_at_utm_boundary(self):
66+
"""Two tiles in EPSG:32650 / 32651 at lon ≈ 120° must mosaic."""
67+
tile_a = _make_tile(119.95, 30.0, "EPSG:32650", value=1.0)
68+
tile_b = _make_tile(120.05, 30.0, "EPSG:32651", value=2.0)
69+
70+
bbox = BBox(minlon=119.90, minlat=29.95, maxlon=120.10, maxlat=30.05)
71+
chw, meta = _mosaic_and_crop_strict_roi([tile_a, tile_b], bbox_4326=bbox)
72+
73+
assert chw.ndim == 3
74+
assert chw.shape[0] == D
75+
# target CRS should be one of the two (whichever has more tiles, or
76+
# first in case of tie)
77+
assert meta["tile_crs"] in ("EPSG:32650", "EPSG:32651")
78+
79+
def test_mixed_crs_majority_wins(self):
80+
"""Target CRS is the most-common one among tiles."""
81+
tile_a = _make_tile(119.8, 30.0, "EPSG:32650", value=1.0)
82+
tile_b = _make_tile(119.9, 30.0, "EPSG:32650", value=1.5)
83+
tile_c = _make_tile(120.1, 30.0, "EPSG:32651", value=2.0)
84+
85+
bbox = BBox(minlon=119.75, minlat=29.95, maxlon=120.15, maxlat=30.05)
86+
chw, meta = _mosaic_and_crop_strict_roi(
87+
[tile_a, tile_b, tile_c], bbox_4326=bbox,
88+
)
89+
90+
assert chw.ndim == 3
91+
assert chw.shape[0] == D
92+
# 2 tiles in 32650, 1 in 32651 → target should be 32650
93+
assert meta["tile_crs"] == "EPSG:32650"
94+
95+
def test_empty_tiles_raises(self):
96+
"""No tiles at all should still raise."""
97+
bbox = BBox(minlon=119.0, minlat=29.0, maxlon=120.0, maxlat=30.0)
98+
with pytest.raises(ModelError, match="No tiles"):
99+
_mosaic_and_crop_strict_roi([], bbox_4326=bbox)
100+
101+
102+
class TestReprojectTile:
103+
"""Unit tests for _reproject_tile helper."""
104+
105+
def test_identity_reproject(self):
106+
"""Reprojecting to the same CRS should return data of similar shape."""
107+
hwc = np.ones((4, 4, D), dtype=np.float32)
108+
transform = Affine(PX_M, 0.0, 500_000.0, 0.0, -PX_M, 3_500_000.0)
109+
110+
out_hwc, out_tf = _reproject_tile(
111+
hwc, transform, "EPSG:32650", "EPSG:32650",
112+
)
113+
114+
assert out_hwc.ndim == 3
115+
assert out_hwc.shape[-1] == D
116+
np.testing.assert_allclose(out_hwc, 1.0, atol=1e-5)
117+
118+
def test_cross_zone_reproject(self):
119+
"""Reproject from zone 50 to zone 51 preserves data values."""
120+
hwc = np.full((4, 4, D), 42.0, dtype=np.float32)
121+
transform = Affine(PX_M, 0.0, 800_000.0, 0.0, -PX_M, 3_500_000.0)
122+
123+
out_hwc, out_tf = _reproject_tile(
124+
hwc, transform, "EPSG:32650", "EPSG:32651",
125+
)
126+
127+
assert out_hwc.ndim == 3
128+
assert out_hwc.shape[-1] == D
129+
# nearest-neighbour: non-zero pixels should carry the original value
130+
nonzero = out_hwc[out_hwc != 0]
131+
if nonzero.size:
132+
np.testing.assert_allclose(nonzero, 42.0, atol=1e-5)
133+
134+
def test_target_res_snaps_resolution(self):
135+
"""Providing target_res should lock the output pixel size."""
136+
hwc = np.ones((4, 4, D), dtype=np.float32)
137+
transform = Affine(PX_M, 0.0, 500_000.0, 0.0, -PX_M, 3_500_000.0)
138+
139+
target_res = (PX_M, PX_M)
140+
out_hwc, out_tf = _reproject_tile(
141+
hwc, transform, "EPSG:32650", "EPSG:32651", target_res=target_res,
142+
)
143+
144+
assert abs(abs(float(out_tf.a)) - PX_M) < 1e-3
145+
assert abs(abs(float(out_tf.e)) - PX_M) < 1e-3

0 commit comments

Comments
 (0)