Skip to content

Commit 4ac7ebe

Browse files
Add SGRID metadata support for node_coordinates attribute (#2456)
1 parent 3df2b13 commit 4ac7ebe

File tree

3 files changed

+106
-73
lines changed

3 files changed

+106
-73
lines changed

src/parcels/_core/utils/sgrid.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
topology_dimension: Literal[2],
5555
node_dimensions: tuple[Dim, Dim],
5656
face_dimensions: tuple[DimDimPadding, DimDimPadding],
57+
node_coordinates: None | tuple[Dim, Dim] = None,
5758
vertical_dimensions: None | tuple[DimDimPadding] = None,
5859
):
5960
if cf_role != "grid_topology":
@@ -76,6 +77,14 @@ def __init__(
7677
):
7778
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")
7879

80+
if node_coordinates is not None:
81+
if not (
82+
isinstance(node_coordinates, tuple)
83+
and len(node_coordinates) == 2
84+
and all(isinstance(nd, str) for nd in node_coordinates)
85+
):
86+
raise ValueError("node_coordinates must be a tuple of 2 dimensions for a 2D grid")
87+
7988
if vertical_dimensions is not None:
8089
if not (
8190
isinstance(vertical_dimensions, tuple)
@@ -90,21 +99,21 @@ def __init__(
9099
self.node_dimensions = node_dimensions
91100
self.face_dimensions = face_dimensions
92101

93-
#! Optional attributes aren't really important to Parcels, can be added later if needed
102+
# Optional attributes
103+
self.node_coordinates = node_coordinates
104+
self.vertical_dimensions = vertical_dimensions
105+
106+
#! Some optional attributes aren't really important to Parcels, can be added later if needed
94107
# Optional attributes
95108
# # With defaults (set in init)
96109
# edge1_dimensions: tuple[Dim, DimDimPadding]
97110
# edge2_dimensions: tuple[DimDimPadding, Dim]
98111

99112
# # Without defaults
100-
# node_coordinates: None | Any = None
101113
# edge1_coordinates: None | Any = None
102114
# edge2_coordinates: None | Any = None
103115
# face_coordinate: None | Any = None
104116

105-
#! Important optional attribute for 2D grids with vertical layering
106-
self.vertical_dimensions = vertical_dimensions
107-
108117
def __repr__(self) -> str:
109118
return repr_from_dunder_dict(self)
110119

@@ -121,6 +130,7 @@ def from_attrs(cls, attrs):
121130
topology_dimension=attrs["topology_dimension"],
122131
node_dimensions=load_mappings(attrs["node_dimensions"]),
123132
face_dimensions=load_mappings(attrs["face_dimensions"]),
133+
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
124134
vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")),
125135
)
126136
except Exception as e:
@@ -133,6 +143,8 @@ def to_attrs(self) -> dict[str, str | int]:
133143
node_dimensions=dump_mappings(self.node_dimensions),
134144
face_dimensions=dump_mappings(self.face_dimensions),
135145
)
146+
if self.node_coordinates is not None:
147+
d["node_coordinates"] = dump_mappings(self.node_coordinates)
136148
if self.vertical_dimensions is not None:
137149
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
138150
return d
@@ -148,6 +160,7 @@ def __init__(
148160
topology_dimension: Literal[3],
149161
node_dimensions: tuple[Dim, Dim, Dim],
150162
volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding],
163+
node_coordinates: None | tuple[Dim, Dim, Dim] = None,
151164
):
152165
if cf_role != "grid_topology":
153166
raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}")
@@ -169,13 +182,24 @@ def __init__(
169182
):
170183
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")
171184

185+
if node_coordinates is not None:
186+
if not (
187+
isinstance(node_coordinates, tuple)
188+
and len(node_coordinates) == 3
189+
and all(isinstance(nd, str) for nd in node_coordinates)
190+
):
191+
raise ValueError("node_coordinates must be a tuple of 3 dimensions for a 3D grid")
192+
172193
# Required attributes
173194
self.cf_role = cf_role
174195
self.topology_dimension = topology_dimension
175196
self.node_dimensions = node_dimensions
176197
self.volume_dimensions = volume_dimensions
177198

178-
# ! Optional attributes aren't really important to Parcels, can be added later if needed
199+
# Optional attributes
200+
self.node_coordinates = node_coordinates
201+
202+
# ! Some optional attributes aren't really important to Parcels, can be added later if needed
179203
# Optional attributes
180204
# # With defaults (set in init)
181205
# edge1_dimensions: tuple[DimDimPadding, Dim, Dim]
@@ -186,7 +210,6 @@ def __init__(
186210
# face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim]
187211

188212
# # Without defaults
189-
# node_coordinates
190213
# edge *i_coordinates*
191214
# face *i_coordinates*
192215
# volume_coordinates
@@ -207,17 +230,21 @@ def from_attrs(cls, attrs):
207230
topology_dimension=attrs["topology_dimension"],
208231
node_dimensions=load_mappings(attrs["node_dimensions"]),
209232
volume_dimensions=load_mappings(attrs["volume_dimensions"]),
233+
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
210234
)
211235
except Exception as e:
212236
raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e
213237

214238
def to_attrs(self) -> dict[str, str | int]:
215-
return dict(
239+
d = dict(
216240
cf_role=self.cf_role,
217241
topology_dimension=self.topology_dimension,
218242
node_dimensions=dump_mappings(self.node_dimensions),
219243
volume_dimensions=dump_mappings(self.volume_dimensions),
220244
)
245+
if self.node_coordinates is not None:
246+
d["node_coordinates"] = dump_mappings(self.node_coordinates)
247+
return d
221248

222249
def rename_dims(self, dims_dict: dict[str, str]) -> Self:
223250
return _metadata_rename_dims(self, dims_dict)

tests/strategies/sgrid.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
@st.composite
3030
def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
31-
N = 6
31+
N = 8
3232
names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True))
3333
node_dimension1 = names[0]
3434
node_dimension2 = names[1]
@@ -37,11 +37,20 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
3737
padding_type1 = draw(padding)
3838
padding_type2 = draw(padding)
3939

40-
vertical_dimensions_dim1 = names[4]
41-
vertical_dimensions_dim2 = names[5]
40+
node_coordinates_var1 = names[4]
41+
node_coordinates_var2 = names[5]
42+
has_node_coordinates = draw(st.booleans())
43+
44+
vertical_dimensions_dim1 = names[6]
45+
vertical_dimensions_dim2 = names[7]
4246
vertical_dimensions_padding = draw(padding)
4347
has_vertical_dimensions = draw(st.booleans())
4448

49+
if has_node_coordinates:
50+
node_coordinates = (node_coordinates_var1, node_coordinates_var2)
51+
else:
52+
node_coordinates = None
53+
4554
if has_vertical_dimensions:
4655
vertical_dimensions = (
4756
sgrid.DimDimPadding(vertical_dimensions_dim1, vertical_dimensions_dim2, vertical_dimensions_padding),
@@ -57,13 +66,14 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
5766
sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1),
5867
sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2),
5968
),
69+
node_coordinates=node_coordinates,
6070
vertical_dimensions=vertical_dimensions,
6171
)
6272

6373

6474
@st.composite
6575
def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
66-
N = 6
76+
N = 9
6777
names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True))
6878
node_dimension1 = names[0]
6979
node_dimension2 = names[1]
@@ -75,6 +85,16 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
7585
padding_type2 = draw(padding)
7686
padding_type3 = draw(padding)
7787

88+
node_coordinates_var1 = names[6]
89+
node_coordinates_var2 = names[7]
90+
node_coordinates_dim3 = names[8]
91+
has_node_coordinates = draw(st.booleans())
92+
93+
if has_node_coordinates:
94+
node_coordinates = (node_coordinates_var1, node_coordinates_var2, node_coordinates_dim3)
95+
else:
96+
node_coordinates = None
97+
7898
return sgrid.Grid3DMetadata(
7999
cf_role="grid_topology",
80100
topology_dimension=3,
@@ -84,6 +104,7 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
84104
sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2),
85105
sgrid.DimDimPadding(face_dimension3, node_dimension3, padding_type3),
86106
),
107+
node_coordinates=node_coordinates,
87108
)
88109

89110

tests/utils/test_sgrid.py

Lines changed: 46 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import numpy as np
24
import pytest
35
import xarray as xr
@@ -7,29 +9,47 @@
79
from parcels._core.utils import sgrid
810
from tests.strategies import sgrid as sgrid_strategies
911

10-
grid2dmetadata = sgrid.Grid2DMetadata(
11-
cf_role="grid_topology",
12-
topology_dimension=2,
13-
node_dimensions=("node_dimension1", "node_dimension2"),
14-
face_dimensions=(
15-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
16-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
17-
),
18-
vertical_dimensions=(
19-
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
20-
),
21-
)
2212

23-
grid3dmetadata = sgrid.Grid3DMetadata(
24-
cf_role="grid_topology",
25-
topology_dimension=3,
26-
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
27-
volume_dimensions=(
28-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
29-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
30-
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
31-
),
32-
)
13+
def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool):
14+
vertical_dimensions = (
15+
(sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),)
16+
if with_vertical_dimensions
17+
else None
18+
)
19+
node_coordinates = ("node_coordinates_var1", "node_coordinates_var2") if with_node_coordinates else None
20+
21+
return sgrid.Grid2DMetadata(
22+
cf_role="grid_topology",
23+
topology_dimension=2,
24+
node_dimensions=("node_dimension1", "node_dimension2"),
25+
face_dimensions=(
26+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
27+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
28+
),
29+
node_coordinates=node_coordinates,
30+
vertical_dimensions=vertical_dimensions,
31+
)
32+
33+
34+
def create_example_grid3dmetadata(with_node_coordinates: bool):
35+
node_coordinates = (
36+
("node_coordinates_var1", "node_coordinates_var2", "node_coordinates_dim3") if with_node_coordinates else None
37+
)
38+
return sgrid.Grid3DMetadata(
39+
cf_role="grid_topology",
40+
topology_dimension=3,
41+
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
42+
volume_dimensions=(
43+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
44+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
45+
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
46+
),
47+
node_coordinates=node_coordinates,
48+
)
49+
50+
51+
grid2dmetadata = create_example_grid2dmetadata(with_vertical_dimensions=True, with_node_coordinates=True)
52+
grid3dmetadata = create_example_grid3dmetadata(with_node_coordinates=True)
3353

3454

3555
def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset:
@@ -225,45 +245,10 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
225245
@pytest.mark.parametrize(
226246
"grid",
227247
[
228-
(
229-
sgrid.Grid2DMetadata(
230-
cf_role="grid_topology",
231-
topology_dimension=2,
232-
node_dimensions=("node_dimension1", "node_dimension2"),
233-
face_dimensions=(
234-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
235-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
236-
),
237-
vertical_dimensions=(
238-
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
239-
),
240-
)
241-
),
242-
(
243-
sgrid.Grid2DMetadata(
244-
cf_role="grid_topology",
245-
topology_dimension=2,
246-
node_dimensions=("node_dimension1", "node_dimension2"),
247-
face_dimensions=(
248-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
249-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
250-
),
251-
vertical_dimensions=None,
252-
)
253-
),
254-
(
255-
sgrid.Grid3DMetadata(
256-
cf_role="grid_topology",
257-
topology_dimension=3,
258-
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
259-
volume_dimensions=(
260-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
261-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
262-
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
263-
),
264-
)
265-
),
266-
],
248+
create_example_grid2dmetadata(with_node_coordinates=i, with_vertical_dimensions=j)
249+
for i, j in itertools.product([False, True], [False, True])
250+
]
251+
+ [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]],
267252
)
268253
def test_rename_dims(grid):
269254
dims = sgrid.get_unique_dim_names(grid)

0 commit comments

Comments
 (0)