Skip to content

Commit 882a640

Browse files
committed
Update API for convert.nemo_to_sgrid
Updates the API for conversion to be more closely aligned with the input data. Also handles the U and V fields separately - correctly assigning the dimension naming before merging into a single dataset.
1 parent d3ae295 commit 882a640

File tree

2 files changed

+55
-28
lines changed

2 files changed

+55
-28
lines changed

src/parcels/convert.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@
3333
"W": ["upward_sea_water_velocity", "vertical_sea_water_velocity"],
3434
}
3535

36-
_NEMO_DIMENSION_NAMES = ["x", "y", "time", "glamf", "gphif", "depth"]
36+
_NEMO_DIMENSION_COORD_NAMES = ["x", "y", "time", "x", "x_center", "y", "y_center", "depth", "glamf", "gphif"]
3737

3838
_NEMO_AXIS_VARNAMES = {
39-
"X": "glamf",
40-
"Y": "gphif",
41-
"Z": "depth",
42-
"T": "time",
39+
"x": "X",
40+
"x_center": "X",
41+
"y": "Y",
42+
"y_center": "Y",
43+
"depth": "Z",
44+
"time": "T",
4345
}
4446

4547
_NEMO_VARNAMES_MAPPING = {
@@ -89,12 +91,12 @@ def _assign_dims_as_coords(ds, dimension_names):
8991
return ds
9092

9193

92-
def _drop_unused_dimensions_and_coords(ds, dimension_names):
94+
def _drop_unused_dimensions_and_coords(ds, dimension_and_coord_names):
9395
for dim in ds.dims:
94-
if dim not in dimension_names:
96+
if dim not in dimension_and_coord_names:
9597
ds = ds.drop_dims(dim, errors="ignore")
9698
for coord in ds.coords:
97-
if coord not in dimension_names:
99+
if coord not in dimension_and_coord_names:
98100
ds = ds.drop_vars(coord, errors="ignore")
99101
return ds
100102

@@ -113,10 +115,9 @@ def _maybe_remove_depth_from_lonlat(ds):
113115
return ds
114116

115117

116-
def _set_axis_attrs(ds, axis_varnames):
117-
for axis, varname in axis_varnames.items():
118-
if varname in ds.coords:
119-
ds[varname].attrs["axis"] = axis
118+
def _set_axis_attrs(ds, dim_axis):
119+
for dim, axis in dim_axis.items():
120+
ds[dim].attrs["axis"] = axis
120121
return ds
121122

122123

@@ -162,16 +163,45 @@ def _discover_U_and_V(ds: xr.Dataset, cf_standard_names_fallbacks) -> xr.Dataset
162163
return ds
163164

164165

165-
def nemo_to_sgrid(ds: xr.Dataset):
166-
ds = ds.copy()
166+
def nemo_to_sgrid(*, coords: xr.Dataset, **fields: dict[str, xr.Dataset]):
167+
fields = fields.copy()
168+
coords = coords[["gphif", "glamf"]]
169+
170+
for name, field_da in fields.items():
171+
if isinstance(field_da, xr.Dataset):
172+
field_da = field_da[name]
173+
# TODO: logging message, warn if multiple fields are in this dataset
174+
175+
match name:
176+
case "U":
177+
field_da = field_da.rename({"y": "y_center"})
178+
case "V":
179+
field_da = field_da.rename({"x": "x_center"})
180+
case _:
181+
pass
182+
183+
fields[name] = field_da
184+
185+
if "time" in coords.dims:
186+
if coords.dims["time"] != 1:
187+
raise ValueError("Time dimension in coords must be length 1 (i.e., no time-varying grid).")
188+
coords = coords.isel(time=0).drop("time")
189+
if len(coords.dims) == 3:
190+
for dim, len_ in coords.dims.items():
191+
if len_ == 1:
192+
# TODO: log statement about selecting along z dim of 1
193+
coords = coords.isel({dim: 0})
194+
if len(coords.dims) != 2:
195+
raise ValueError("Expected coordsinates to be 2 dimensional")
196+
197+
ds = xr.merge(list(fields.values()) + [coords])
167198
ds = _maybe_rename_variables(ds, _NEMO_VARNAMES_MAPPING)
168199
ds = _discover_U_and_V(ds, _NEMO_CF_STANDARD_NAME_FALLBACKS)
169200
ds = _maybe_create_depth_dim(ds)
170201
ds = _maybe_bring_UV_depths_to_depth(ds)
171-
ds = _drop_unused_dimensions_and_coords(ds, _NEMO_DIMENSION_NAMES)
172-
ds = _maybe_rename_coords(ds, _NEMO_AXIS_VARNAMES)
173-
ds = _assign_dims_as_coords(ds, _NEMO_DIMENSION_NAMES)
174-
ds = _set_coords(ds, _NEMO_DIMENSION_NAMES)
202+
ds = _drop_unused_dimensions_and_coords(ds, _NEMO_DIMENSION_COORD_NAMES)
203+
ds = _assign_dims_as_coords(ds, _NEMO_DIMENSION_COORD_NAMES)
204+
ds = _set_coords(ds, _NEMO_DIMENSION_COORD_NAMES)
175205
ds = _maybe_remove_depth_from_lonlat(ds)
176206
ds = _set_axis_attrs(ds, _NEMO_AXIS_VARNAMES)
177207

tests/test_convert.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
def test_nemo_to_sgrid():
99
data_folder = parcels.download_example_dataset("NemoCurvilinear_data")
10-
ds_fields = xr.open_mfdataset(
11-
data_folder.glob("*.nc4"),
12-
data_vars="minimal",
13-
coords="minimal",
14-
compat="override",
15-
)
16-
ds_fields = ds_fields.isel(time=0, z_a=0, z=0, drop=True)
17-
ds = convert.nemo_to_sgrid(ds_fields)
10+
U = xr.open_mfdataset(data_folder.glob("*U.nc4"))
11+
V = xr.open_mfdataset(data_folder.glob("*V.nc4"))
12+
coords = xr.open_dataset(data_folder / "mesh_mask.nc4")
13+
14+
ds = convert.nemo_to_sgrid(U=U, V=V, coords=coords)
1815

1916
assert ds["grid"].attrs == {
2017
"cf_role": "grid_topology",
@@ -32,8 +29,8 @@ def test_nemo_to_sgrid():
3229
assert {
3330
meta.get_value_by_id("node_dimension1"), # X edge
3431
meta.get_value_by_id("face_dimension2"), # Y center
35-
} in set(ds["U"].dims)
32+
}.issubset(set(ds["U"].dims))
3633
assert {
3734
meta.get_value_by_id("face_dimension1"), # X center
3835
meta.get_value_by_id("node_dimension2"), # Y edge
39-
} in set(ds["V"].dims)
36+
}.issubset(set(ds["V"].dims))

0 commit comments

Comments
 (0)