Skip to content

Commit afc820a

Browse files
Update sgrid renaming tooling to work with node_coordinates (#2466)
1 parent 99e6d18 commit afc820a

File tree

4 files changed

+99
-49
lines changed

4 files changed

+99
-49
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ docs/_build/*
33
docs/_downloads
44
docs/jupyter_execute/*
55
docs/.jupyter_cache/*
6+
docs/reference
67
output
78

89
*.log

src/parcels/_core/utils/sgrid.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def to_attrs(self) -> dict[str, str | int]:
151151
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
152152
return d
153153

154-
def rename_dims(self, dims_dict: dict[str, str]) -> Self:
155-
return _metadata_rename_dims(self, dims_dict)
154+
def rename(self, names_dict: dict[str, str]) -> Self:
155+
return _metadata_rename(self, names_dict)
156156

157157

158158
class Grid3DMetadata(AttrsSerializable):
@@ -248,8 +248,8 @@ def to_attrs(self) -> dict[str, str | int]:
248248
d["node_coordinates"] = dump_mappings(self.node_coordinates)
249249
return d
250250

251-
def rename_dims(self, dims_dict: dict[str, str]) -> Self:
252-
return _metadata_rename_dims(self, dims_dict)
251+
def rename(self, dims_dict: dict[str, str]) -> Self:
252+
return _metadata_rename(self, dims_dict)
253253

254254

255255
@dataclass
@@ -418,22 +418,22 @@ def parse_sgrid(ds: xr.Dataset):
418418
return (ds, {"coords": xgcm_coords})
419419

420420

421-
def rename_dims(ds: xr.Dataset, dims_dict: dict[str, str]) -> xr.Dataset:
421+
def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset:
422422
grid_da = get_grid_topology(ds)
423423
if grid_da is None:
424424
raise ValueError(
425425
"No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions."
426426
)
427427

428-
ds = ds.rename_dims(dims_dict)
428+
ds = ds.rename(name_dict)
429429

430430
# Update the metadata
431431
grid = parse_grid_attrs(grid_da.attrs)
432-
ds[grid_da.name].attrs = grid.rename_dims(dims_dict).to_attrs()
432+
ds[grid_da.name].attrs = grid.rename(name_dict).to_attrs()
433433
return ds
434434

435435

436-
def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
436+
def get_unique_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
437437
dims = set()
438438
dims.update(set(grid.node_dimensions))
439439

@@ -453,14 +453,6 @@ def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
453453
return dims
454454

455455

456-
@overload
457-
def _metadata_rename_dims(grid: Grid2DMetadata, dims_dict: dict[str, str]) -> Grid2DMetadata: ...
458-
459-
460-
@overload
461-
def _metadata_rename_dims(grid: Grid3DMetadata, dims_dict: dict[str, str]) -> Grid3DMetadata: ...
462-
463-
464456
def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
465457
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
466458
ds = ds.copy()
@@ -473,24 +465,32 @@ def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
473465
return ds
474466

475467

476-
def _metadata_rename_dims(grid, dims_dict):
468+
@overload
469+
def _metadata_rename(grid: Grid2DMetadata, names_dict: dict[str, str]) -> Grid2DMetadata: ...
470+
471+
472+
@overload
473+
def _metadata_rename(grid: Grid3DMetadata, names_dict: dict[str, str]) -> Grid3DMetadata: ...
474+
475+
476+
def _metadata_rename(grid, names_dict):
477477
"""
478-
Renames dimensions in SGrid metadata.
478+
Renames dimensions and coordinates in SGrid metadata.
479479
480-
Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping
480+
Similar in API to xr.Dataset.rename . Renames dimensions according to names_dict mapping
481481
of old dimension names to new dimension names.
482482
"""
483-
dims_dict = dims_dict.copy()
484-
assert len(dims_dict) == len(set(dims_dict.values())), "dims_dict contains duplicate target dimension names"
483+
names_dict = names_dict.copy()
484+
assert len(names_dict) == len(set(names_dict.values())), "names_dict contains duplicate target dimension names"
485485

486-
existing_dims = get_unique_dim_names(grid)
487-
for dim in dims_dict.keys():
488-
if dim not in existing_dims:
489-
raise ValueError(f"Dimension {dim!r} not found in SGrid metadata dimensions {existing_dims!r}")
486+
existing_names = get_unique_names(grid)
487+
for name in names_dict.keys():
488+
if name not in existing_names:
489+
raise ValueError(f"Name {name!r} not found in names defined in SGrid metadata {existing_names!r}")
490490

491-
for dim in existing_dims:
492-
if dim not in dims_dict:
493-
dims_dict[dim] = dim # identity mapping for dimensions not being renamed
491+
for name in existing_names:
492+
if name not in names_dict:
493+
names_dict[name] = name # identity mapping for names not being renamed
494494

495495
kwargs = {}
496496
for key, value in grid.__dict__.items():
@@ -499,14 +499,14 @@ def _metadata_rename_dims(grid, dims_dict):
499499
for item in value:
500500
if isinstance(item, DimDimPadding):
501501
new_item = DimDimPadding(
502-
dim1=dims_dict[item.dim1],
503-
dim2=dims_dict[item.dim2],
502+
dim1=names_dict[item.dim1],
503+
dim2=names_dict[item.dim2],
504504
padding=item.padding,
505505
)
506506
new_value.append(new_item)
507507
else:
508508
assert isinstance(item, str)
509-
new_value.append(dims_dict[item])
509+
new_value.append(names_dict[item])
510510
kwargs[key] = tuple(new_value)
511511
continue
512512

@@ -515,7 +515,7 @@ def _metadata_rename_dims(grid, dims_dict):
515515
continue
516516

517517
if isinstance(value, str):
518-
kwargs[key] = dims_dict[value]
518+
kwargs[key] = names_dict[value]
519519
continue
520520

521521
raise ValueError(f"Unexpected attribute {key!r} on {grid!r}")

src/parcels/_datasets/structured/generic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
_attach_sgrid_metadata,
99
)
1010
from parcels._core.utils.sgrid import (
11-
rename_dims as sgrid_rename_dims,
11+
rename as sgrid_rename,
1212
)
1313
from parcels._datasets.utils import _attach_sgrid_metadata
1414

@@ -258,11 +258,12 @@ def _unrolled_cone_curvilinear_grid():
258258
DimDimPadding("XC", "XG", Padding.HIGH),
259259
DimDimPadding("YC", "YG", Padding.HIGH),
260260
),
261+
node_coordinates=("lon", "lat"),
261262
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),),
262263
),
263264
)
264265
.pipe(
265-
sgrid_rename_dims,
266+
sgrid_rename,
266267
_COMODO_TO_2D_SGRID,
267268
)
268269
),
@@ -278,11 +279,12 @@ def _unrolled_cone_curvilinear_grid():
278279
DimDimPadding("XC", "XG", Padding.LOW),
279280
DimDimPadding("YC", "YG", Padding.LOW),
280281
),
282+
node_coordinates=("lon", "lat"),
281283
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),),
282284
),
283285
)
284286
.pipe(
285-
sgrid_rename_dims,
287+
sgrid_rename,
286288
_COMODO_TO_2D_SGRID,
287289
)
288290
),

tests/utils/test_sgrid.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset:
6565
ds = dummy_comodo_3d_ds()
6666

6767
# Can't rename dimensions that already exist in the dataset
68-
assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set())
68+
assume(sgrid.get_unique_names(grid) & set(ds.dims) == set())
6969

7070
renamings = {}
7171
if grid.vertical_dimensions is None:
@@ -90,7 +90,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset:
9090
ds = dummy_comodo_3d_ds()
9191

9292
# Can't rename dimensions that already exist in the dataset
93-
assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set())
93+
assume(sgrid.get_unique_names(grid) & set(ds.dims) == set())
9494

9595
renamings = {}
9696
for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True):
@@ -250,30 +250,77 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
250250
]
251251
+ [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]],
252252
)
253-
def test_rename_dims(grid):
254-
dims = sgrid.get_unique_dim_names(grid)
253+
def test_rename(grid):
254+
dims = sgrid.get_unique_names(grid)
255255
dims_dict = {dim: f"new_{dim}" for dim in dims}
256256
dims_dict_inv = {v: k for k, v in dims_dict.items()}
257257

258-
grid_new = grid.rename_dims(dims_dict)
259-
assert dims & set(sgrid.get_unique_dim_names(grid_new)) == set()
258+
grid_new = grid.rename(dims_dict)
259+
assert dims & set(sgrid.get_unique_names(grid_new)) == set()
260260

261-
assert grid == grid_new.rename_dims(dims_dict_inv)
261+
assert grid == grid_new.rename(dims_dict_inv)
262262

263263

264-
def test_rename_dims_errors():
264+
def test_rename_errors():
265265
# Test various error modes of rename_dims
266266
grid = grid2dmetadata
267267
# Non-unique target dimension names
268-
dims_dict = {
268+
names_dict = {
269269
"node_dimension1": "new_node_dimension",
270270
"node_dimension2": "new_node_dimension",
271271
}
272-
with pytest.raises(AssertionError, match="dims_dict contains duplicate target dimension names"):
273-
grid.rename_dims(dims_dict)
272+
with pytest.raises(AssertionError, match="names_dict contains duplicate target dimension names"):
273+
grid.rename(names_dict)
274274
# Unexpected attribute in dims_dict
275-
dims_dict = {
275+
names_dict = {
276276
"unexpected_dimension": "new_unexpected_dimension",
277277
}
278-
with pytest.raises(ValueError, match="Dimension 'unexpected_dimension' not found in SGrid metadata dimensions"):
279-
grid.rename_dims(dims_dict)
278+
with pytest.raises(ValueError, match="Name 'unexpected_dimension' not found in names defined in SGrid metadata"):
279+
grid.rename(names_dict)
280+
281+
282+
@pytest.mark.parametrize(
283+
"ds",
284+
[
285+
xr.Dataset(
286+
{
287+
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(10, 10, 10, 10)),
288+
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)),
289+
"grid": (
290+
[],
291+
np.array(0),
292+
sgrid.Grid2DMetadata(
293+
cf_role="grid_topology",
294+
topology_dimension=2,
295+
node_dimensions=("XG", "YG"),
296+
face_dimensions=(
297+
sgrid.DimDimPadding("XC", "XG", sgrid.Padding.HIGH),
298+
sgrid.DimDimPadding("YC", "YG", sgrid.Padding.HIGH),
299+
),
300+
vertical_dimensions=(sgrid.DimDimPadding("ZC", "ZG", sgrid.Padding.HIGH),),
301+
node_coordinates=("lon", "lat"),
302+
).to_attrs(),
303+
),
304+
},
305+
coords={
306+
"lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)),
307+
"lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)),
308+
"depth": (["ZG"], np.arange(10)),
309+
"time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}),
310+
},
311+
),
312+
],
313+
)
314+
def test_rename_dataset(ds):
315+
# Check renaming works for coordinates
316+
ds_new = sgrid.rename(ds, {"lon": "lon_updated"})
317+
grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs)
318+
assert "lon_updated" in ds_new.coords
319+
assert "lon_updated" == grid_new.node_coordinates[0]
320+
321+
# Check renaming works for dim
322+
ds_new = sgrid.rename(ds, {"XC": "XC_updated"})
323+
grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs)
324+
assert "XC_updated" in ds_new.dims
325+
assert "XC" not in ds_new.dims
326+
assert "XC_updated" == grid_new.face_dimensions[0].dim1

0 commit comments

Comments
 (0)