Skip to content

Commit a93d621

Browse files
authored
use cf-xarray to infer cell geometries (#14)
* refactor to allow reusing the example datasets
* type hints
* draft implementation of the inferring of cell geometries
* implement and test the actual derivation of geometries
* refuse guessing for unstructured
* explicitly construct the variables and geometries
* also test 2d-curvilinear
* remove the xfailed `1d-unstructured` test: since we refuse to guess, it doesn't make sense to even test this.
1 parent 4f363fd commit a93d621

File tree

2 files changed

+215
-60
lines changed

2 files changed

+215
-60
lines changed

python/grid_indexing/grids.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
from typing import Literal
2+
13
import cf_xarray # noqa: F401
24
import numpy as np
5+
import shapely
6+
import xarray as xr
7+
from numpy.typing import ArrayLike
38

49

5-
def is_meshgrid(coord1, coord2):
10+
def is_meshgrid(coord1: ArrayLike, coord2: ArrayLike) -> bool:
    """Check whether two 2D coordinate arrays form a meshgrid.

    The pair counts as a meshgrid when one array is constant along the
    first axis while the other is constant along the second axis, in either
    assignment.

    Every row / column is compared against the first one, so arrays that
    only start to differ after the second row / column are rejected, and
    axes of length 1 are treated as trivially constant instead of failing
    on a hard-coded index.

    Parameters
    ----------
    coord1, coord2 : array-like
        Two-dimensional coordinate arrays. They are converted with
        ``np.asarray`` before being compared.

    Returns
    -------
    bool
        ``True`` if the two arrays form a meshgrid, ``False`` otherwise.
    """
    first = np.asarray(coord1)
    second = np.asarray(coord2)

    def constant_along(values, axis):
        # Keep the sliced axis (length 1) so the comparison broadcasts the
        # first row / column against all the others.
        reference = values[:1, :] if axis == 0 else values[:, :1]
        return bool(np.all(values == reference))

    return (constant_along(first, 0) and constant_along(second, 1)) or (
        constant_along(first, 1) and constant_along(second, 0)
    )
914

1015

11-
def infer_grid_type(ds):
16+
def infer_grid_type(ds: xr.Dataset):
1217
# grid types (all geographic):
1318
# - 2d crs (affine transform)
1419
# - 1d orthogonal (rectilinear)
@@ -46,3 +51,52 @@ def infer_grid_type(ds):
4651
return "2d-curvilinear"
4752
else:
4853
raise ValueError("unable to infer the grid type")
54+
55+
56+
def infer_cell_geometries(
57+
ds: xr.Dataset,
58+
*,
59+
grid_type: str = "infer",
60+
coords: Literal["infer"] | list[str] = "infer",
61+
):
62+
# TODO: short-circuit for existing geometries
63+
if grid_type == "infer":
64+
grid_type = infer_grid_type(ds)
65+
66+
if grid_type == "2d-crs":
67+
raise NotImplementedError(
68+
"inferring cell geometries is not yet implemented"
69+
" for geotransform-based grids"
70+
)
71+
elif grid_type == "1d-unstructured":
72+
raise ValueError(
73+
"inferring cell geometries is not implemented for unstructured grids."
74+
" This is hard to get right in all cases, so please manually"
75+
" create the geometries."
76+
)
77+
78+
if coords == "infer":
79+
coords = ["longitude", "latitude"]
80+
if any(coord not in ds.cf.coordinates for coord in coords):
81+
raise ValueError(
82+
"cannot infer geographic coordinates. Please add them"
83+
" or explicitly pass the names if they exist."
84+
)
85+
86+
coords_only = ds.cf[coords]
87+
if grid_type == "1d-rectilinear":
88+
coord_names = [ds.cf.coordinates[name][0] for name in coords]
89+
[broadcasted] = xr.broadcast(
90+
coords_only.drop_indexes(coord_names).reset_coords(coord_names)
91+
)
92+
coords_only = broadcasted.set_coords(coord_names)
93+
94+
if any(coord not in coords_only.cf.bounds for coord in coords):
95+
with_bounds = coords_only.cf.add_bounds(coords)
96+
else:
97+
with_bounds = coords_only
98+
99+
bound_names = [with_bounds.cf.bounds[name][0] for name in coords]
100+
boundaries = np.stack([with_bounds.variables[n].data for n in bound_names], axis=-1)
101+
102+
return shapely.polygons(boundaries)

python/tests/test_grids.py

Lines changed: 159 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,130 @@
11
import numpy as np
22
import pytest
3+
import shapely
4+
import shapely.testing
35
import xarray as xr
46

57
from grid_indexing import grids
68

79

8-
class TestInferGridType:
9-
def test_rectilinear_1d(self):
10-
lat = xr.Variable("lat", np.linspace(-10, 10, 3), {"standard_name": "latitude"})
11-
lon = xr.Variable("lon", np.linspace(-5, 5, 4), {"standard_name": "longitude"})
12-
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
13-
14-
actual = grids.infer_grid_type(ds)
15-
assert actual == "1d-rectilinear"
16-
17-
def test_rectilinear_2d(self):
18-
lat_, lon_ = np.meshgrid(np.linspace(-10, 10, 3), np.linspace(-5, 5, 4))
19-
lat = xr.Variable(["y", "x"], lat_, {"standard_name": "latitude"})
20-
lon = xr.Variable(["y", "x"], lon_, {"standard_name": "longitude"})
21-
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
22-
23-
actual = grids.infer_grid_type(ds)
24-
assert actual == "2d-rectilinear"
25-
26-
def test_curvilinear_2d(self):
27-
lat_ = np.array([[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]])
28-
lon_ = np.array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]])
29-
30-
lat = xr.Variable(["y", "x"], lat_, {"standard_name": "latitude"})
31-
lon = xr.Variable(["y", "x"], lon_, {"standard_name": "longitude"})
32-
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
33-
34-
actual = grids.infer_grid_type(ds)
35-
assert actual == "2d-curvilinear"
36-
37-
def test_unstructured_1d(self):
38-
lat = xr.Variable(
39-
"cells", np.linspace(-10, 10, 12), {"standard_name": "latitude"}
40-
)
41-
lon = xr.Variable(
42-
"cells", np.linspace(-5, 5, 12), {"standard_name": "longitude"}
43-
)
44-
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
45-
46-
actual = grids.infer_grid_type(ds)
47-
48-
assert actual == "1d-unstructured"
10+
def example_dataset(grid_type):
11+
match grid_type:
12+
case "1d-rectilinear":
13+
lat_ = np.array([0, 2])
14+
lon_ = np.array([0, 2, 4])
15+
lat = xr.Variable("lat", lat_, {"standard_name": "latitude"})
16+
lon = xr.Variable("lon", lon_, {"standard_name": "longitude"})
17+
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
18+
case "2d-rectilinear":
19+
lat_, lon_ = np.meshgrid(np.array([0, 2]), np.array([0, 2, 4]))
20+
lat = xr.Variable(["y", "x"], lat_, {"standard_name": "latitude"})
21+
lon = xr.Variable(["y", "x"], lon_, {"standard_name": "longitude"})
22+
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
23+
case "2d-curvilinear":
24+
lat_ = np.array([[0, 0, 0], [2, 2, 2]])
25+
lon_ = np.array([[0, 2, 4], [2, 4, 6]])
26+
27+
lat = xr.Variable(["y", "x"], lat_, {"standard_name": "latitude"})
28+
lon = xr.Variable(["y", "x"], lon_, {"standard_name": "longitude"})
29+
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
30+
case "1d-unstructured":
31+
lat_ = np.arange(12)
32+
lon_ = np.arange(-5, 7)
33+
lat = xr.Variable("cells", lat_, {"standard_name": "latitude"})
34+
lon = xr.Variable("cells", lon_, {"standard_name": "longitude"})
35+
ds = xr.Dataset(coords={"lat": lat, "lon": lon})
36+
case "2d-crs":
37+
data = np.linspace(-10, 10, 12).reshape(3, 4)
38+
geo_transform = (
39+
"101985.0 300.0379266750948 0.0 2826915.0 0.0 -300.041782729805"
40+
)
41+
42+
attrs = {
43+
"grid_mapping_name": "transverse_mercator",
44+
"GeoTransform": geo_transform,
45+
}
4946

50-
def test_crs_2d(self):
51-
data = np.linspace(-10, 10, 12).reshape(3, 4)
52-
geo_transform = "101985.0 300.0379266750948 0.0 2826915.0 0.0 -300.041782729805"
47+
ds = xr.Dataset(
48+
{"band_data": (["y", "x"], data)},
49+
coords={"spatial_ref": ((), np.array(0), attrs)},
50+
)
51+
52+
return ds
53+
54+
55+
def example_geometries(grid_type):
56+
if grid_type == "2d-crs":
57+
raise NotImplementedError
58+
59+
match grid_type:
60+
case "1d-rectilinear":
61+
boundaries = np.array(
62+
[
63+
[
64+
[[-1, -1], [-1, 1], [1, 1], [1, -1]],
65+
[[-1, 1], [-1, 3], [1, 3], [1, 1]],
66+
],
67+
[
68+
[[1, -1], [1, 1], [3, 1], [3, -1]],
69+
[[1, 1], [1, 3], [3, 3], [3, 1]],
70+
],
71+
[
72+
[[3, -1], [3, 1], [5, 1], [5, -1]],
73+
[[3, 1], [3, 3], [5, 3], [5, 1]],
74+
],
75+
]
76+
)
77+
case "2d-rectilinear":
78+
boundaries = np.array(
79+
[
80+
[
81+
[[-1, -1], [-1, 1], [1, 1], [1, -1]],
82+
[[-1, 1], [-1, 3], [1, 3], [1, 1]],
83+
],
84+
[
85+
[[1, -1], [1, 1], [3, 1], [3, -1]],
86+
[[1, 1], [1, 3], [3, 3], [3, 1]],
87+
],
88+
[
89+
[[3, -1], [3, 1], [5, 1], [5, -1]],
90+
[[3, 1], [3, 3], [5, 3], [5, 1]],
91+
],
92+
]
93+
)
94+
case "2d-curvilinear":
95+
boundaries = np.array(
96+
[
97+
[
98+
[[-2, -1], [0, -1], [2, 1], [0, 1]],
99+
[[0, -1], [2, -1], [4, 1], [2, 1]],
100+
[[2, -1], [4, -1], [6, 1], [4, 1]],
101+
],
102+
[
103+
[[0, 1], [2, 1], [4, 3], [2, 3]],
104+
[[2, 1], [4, 1], [6, 3], [4, 3]],
105+
[[4, 1], [6, 1], [8, 3], [6, 3]],
106+
],
107+
]
108+
)
109+
110+
return shapely.polygons(boundaries)
53111

54-
ds = xr.Dataset(
55-
{"band_data": (["y", "x"], data)},
56-
coords={
57-
"spatial_ref": (
58-
(),
59-
np.array(0),
60-
{
61-
"grid_mapping_name": "transverse_mercator",
62-
"GeoTransform": geo_transform,
63-
},
64-
)
65-
},
66-
)
67112

113+
class TestInferGridType:
114+
@pytest.mark.parametrize(
115+
"grid_type",
116+
[
117+
"1d-rectilinear",
118+
"2d-rectilinear",
119+
"2d-curvilinear",
120+
"1d-unstructured",
121+
"2d-crs",
122+
],
123+
)
124+
def test_infer_grid_type(self, grid_type):
125+
ds = example_dataset(grid_type)
68126
actual = grids.infer_grid_type(ds)
69-
assert actual == "2d-crs"
127+
assert actual == grid_type
70128

71129
def test_missing_spatial_coordinates(self):
72130
ds = xr.Dataset()
@@ -86,3 +144,46 @@ def test_unknown_grid_type(self):
86144

87145
with pytest.raises(ValueError, match="unable to infer the grid type"):
88146
grids.infer_grid_type(ds)
147+
148+
149+
class TestInferCellGeometries:
    """Tests for ``grids.infer_cell_geometries``."""

    @pytest.mark.parametrize(
        ("grid_type", "error", "pattern"),
        [
            pytest.param("2d-crs", NotImplementedError, "geotransform", id="2d-crs"),
            pytest.param(
                "1d-unstructured",
                ValueError,
                "unstructured grids",
                id="1d-unstructured",
            ),
        ],
    )
    def test_not_supported(self, grid_type, error, pattern):
        """Unsupported grid types raise the expected error."""
        dataset = example_dataset(grid_type)

        with pytest.raises(error, match=pattern):
            grids.infer_cell_geometries(dataset)

    def test_infer_coords(self):
        """A dataset without geographic coordinates is rejected."""
        empty = xr.Dataset()

        with pytest.raises(ValueError, match="cannot infer geographic coordinates"):
            grids.infer_cell_geometries(empty, grid_type="2d-rectilinear")

    @pytest.mark.parametrize(
        "grid_type",
        (
            "1d-rectilinear",
            "2d-rectilinear",
            "2d-curvilinear",
            pytest.param(
                "2d-crs", marks=pytest.mark.xfail(reason="not yet implemented")
            ),
        ),
    )
    def test_infer_geoms(self, grid_type):
        """The inferred geometries match the hard-coded example polygons."""
        dataset = example_dataset(grid_type)
        reference = example_geometries(grid_type)

        inferred = grids.infer_cell_geometries(dataset, grid_type=grid_type)

        shapely.testing.assert_geometries_equal(inferred, reference)

0 commit comments

Comments
 (0)