Skip to content

Commit 0158fe4

Browse files
Merge branch 'v4-dev' into particledata_as_dict
2 parents f681162 + e576d39 commit 0158fe4

File tree

12 files changed

+183
-61
lines changed

12 files changed

+183
-61
lines changed

parcels/field.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
_raise_field_out_of_bound_error,
3030
)
3131
from parcels.uxgrid import UxGrid
32-
from parcels.xgrid import XGrid
32+
from parcels.xgrid import XGrid, _transpose_xfield_data_to_tzyx
3333

3434
from ._index_search import _search_time_index
3535

@@ -146,6 +146,9 @@ def __init__(
146146

147147
_assert_compatible_combination(data, grid)
148148

149+
if isinstance(grid, XGrid):
150+
data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid)
151+
149152
self.name = name
150153
self.data = data
151154
self.grid = grid
@@ -186,8 +189,9 @@ def __init__(
186189
else:
187190
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")
188191

189-
if "time" not in self.data.dims:
190-
raise ValueError("Field is missing a 'time' dimension. ")
192+
if self.data.shape[0] > 1:
193+
if "time" not in self.data.coords:
194+
raise ValueError("Field data is missing a 'time' coordinate.")
191195

192196
@property
193197
def units(self):
@@ -439,7 +443,7 @@ def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux
439443

440444

441445
def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None:
442-
if len(data.time) == 1:
446+
if data.shape[0] == 1:
443447
return None
444448

445449
return TimeInterval(data.time.values[0], data.time.values[-1])

parcels/fieldset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
132132
"""
133133
da = xr.DataArray(
134134
data=np.full((1, 1, 1, 1), value),
135-
dims=["time", "ZG", "YG", "XG"],
136135
)
137136
grid = XGrid(xgcm.Grid(da))
138137
self.add_field(

parcels/xgrid.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Hashable, Mapping
1+
from collections.abc import Hashable, Mapping, Sequence
22
from functools import cached_property
33
from typing import Literal, cast
44

@@ -10,13 +10,17 @@
1010
from parcels._index_search import _search_indices_curvilinear_2d
1111
from parcels.basegrid import BaseGrid
1212

13-
_XGRID_AXES_ORDERING = "ZYX"
1413
_XGRID_AXES = Literal["X", "Y", "Z"]
14+
_XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX"
1515

1616
_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"]
1717
_XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"]
1818
_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis]
1919

20+
_FIELD_DATA_ORDERING: Sequence[_XGCM_AXIS_DIRECTION] = "TZYX"
21+
22+
_DEFAULT_XGCM_KWARGS = {"periodic": False}
23+
2024

2125
def get_cell_count_along_dim(axis: xgcm.Axis) -> int:
2226
first_coord = list(axis.coords.items())[0]
@@ -34,6 +38,48 @@ def _get_xgrid_axes(grid: xgcm.Grid) -> list[_XGRID_AXES]:
3438
return sorted(spatial_axes, key=_XGRID_AXES_ORDERING.index)
3539

3640

41+
def _drop_field_data(ds: xr.Dataset) -> xr.Dataset:
42+
"""
43+
Removes DataArrays from the dataset that are associated with field data so that
44+
when passed to the XGCM grid, the object only functions as an in memory representation
45+
of the grid.
46+
"""
47+
return ds.drop_vars(ds.data_vars)
48+
49+
50+
def _transpose_xfield_data_to_tzyx(da: xr.DataArray, xgcm_grid: xgcm.Grid) -> xr.DataArray:
51+
"""
52+
Transpose a DataArray of any shape into a 4D array of order TZYX. Uses xgcm to determine
53+
the axes, and inserts mock dimensions of size 1 for any axes not present in the DataArray.
54+
"""
55+
ax_dims = [(get_axis_from_dim_name(xgcm_grid.axes, dim), dim) for dim in da.dims]
56+
57+
if all(ax_dim[0] is None for ax_dim in ax_dims):
58+
# Assuming its a 1D constant field (hence has no axes)
59+
assert da.shape == (1, 1, 1, 1)
60+
return da.rename({old_dim: f"mock{axis}" for old_dim, axis in zip(da.dims, _FIELD_DATA_ORDERING, strict=True)})
61+
62+
# All dimensions must be associated with an axis in the grid
63+
if any(ax_dim[0] is None for ax_dim in ax_dims):
64+
raise ValueError(
65+
f"DataArray {da.name!r} with dims {da.dims} has dimensions that are not associated with a direction on the provided grid."
66+
)
67+
68+
axes_not_in_field = set(_FIELD_DATA_ORDERING) - set(ax_dim[0] for ax_dim in ax_dims)
69+
70+
mock_dims_to_create = {}
71+
for ax in axes_not_in_field:
72+
mock_dims_to_create[f"mock{ax}"] = 1
73+
ax_dims.append((ax, f"mock{ax}"))
74+
75+
if mock_dims_to_create:
76+
da = da.expand_dims(mock_dims_to_create, create_index_for_new_dim=False)
77+
78+
ax_dims = sorted(ax_dims, key=lambda x: _FIELD_DATA_ORDERING.index(x[0]))
79+
80+
return da.transpose(*[ax_dim[1] for ax_dim in ax_dims])
81+
82+
3783
class XGrid(BaseGrid):
3884
"""
3985
Class to represent a structured grid in Parcels. Wraps a xgcm-like Grid object (we use a trimmed down version of the xgcm.Grid class that is vendored with Parcels).
@@ -53,6 +99,18 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
5399
if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
54100
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
55101

102+
@classmethod
103+
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
104+
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
105+
if xgcm_kwargs is None:
106+
xgcm_kwargs = {}
107+
108+
xgcm_kwargs = {**_DEFAULT_XGCM_KWARGS, **xgcm_kwargs}
109+
110+
ds = _drop_field_data(ds)
111+
grid = xgcm.Grid(ds, **xgcm_kwargs)
112+
return cls(grid, mesh=mesh)
113+
56114
@property
57115
def axes(self) -> list[_XGRID_AXES]:
58116
return _get_xgrid_axes(self.xgcm_grid)

tests/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""General helper functions and utilies for test suite."""
22

3+
from __future__ import annotations
4+
35
from pathlib import Path
6+
from typing import TYPE_CHECKING
47

58
import numpy as np
69
import xarray as xr
710

811
import parcels
912
from parcels import FieldSet
13+
from parcels.xgrid import _FIELD_DATA_ORDERING, get_axis_from_dim_name
14+
15+
if TYPE_CHECKING:
16+
from parcels.xgrid import XGrid
1017

1118
PROJECT_ROOT = Path(__file__).resolve().parents[1]
1219
TEST_ROOT = PROJECT_ROOT / "tests"
@@ -116,3 +123,13 @@ def create_fieldset_zeros_simple(xdim=40, ydim=100, withtime=False):
116123

117124
def assert_empty_folder(path: Path):
118125
assert [p.name for p in path.iterdir()] == []
126+
127+
128+
def assert_valid_field_data(data: xr.DataArray, grid: XGrid):
129+
assert len(data.shape) == 4, f"Field data should have 4 dimensions (time, depth, lat, lon), got dims {data.dims}"
130+
131+
for ax_expected, dim in zip(_FIELD_DATA_ORDERING, data.dims, strict=True):
132+
ax_actual = get_axis_from_dim_name(grid.xgcm_grid.axes, dim)
133+
if ax_actual is None:
134+
continue # None is ok
135+
assert ax_actual == ax_expected, f"Expected axis {ax_expected} for dimension '{dim}', got {ax_actual}"

tests/v4/test_datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from parcels import xgcm
12
from parcels._datasets.structured.generic import datasets
2-
from parcels.xgcm import Grid
33

44

55
def test_left_indexed_dataset():
66
"""Checks that 'ds_2d_left' is right indexed on all variables."""
77
ds = datasets["ds_2d_left"]
8-
grid = Grid(ds)
8+
grid = xgcm.Grid(ds)
99

1010
for _axis_name, axis in grid.axes.items():
1111
for pos, _dim_name in axis.coords.items():
@@ -15,7 +15,7 @@ def test_left_indexed_dataset():
1515
def test_right_indexed_dataset():
1616
"""Checks that 'ds_2d_right' is right indexed on all variables."""
1717
ds = datasets["ds_2d_right"]
18-
grid = Grid(ds)
18+
grid = xgcm.Grid(ds)
1919
for _axis_name, axis in grid.axes.items():
2020
for pos, _dim_name in axis.coords.items():
2121
assert pos in ["center", "right"]

tests/v4/test_field.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uxarray as ux
66
import xarray as xr
77

8-
from parcels import Field, UXPiecewiseConstantFace, UXPiecewiseLinearNode, VectorField, xgcm
8+
from parcels import Field, UXPiecewiseConstantFace, UXPiecewiseLinearNode, VectorField
99
from parcels._datasets.structured.generic import T as T_structured
1010
from parcels._datasets.structured.generic import datasets as datasets_structured
1111
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
@@ -15,7 +15,7 @@
1515

1616
def test_field_init_param_types():
1717
data = datasets_structured["ds_2d_left"]
18-
grid = XGrid(xgcm.Grid(data))
18+
grid = XGrid.from_dataset(data)
1919
with pytest.raises(ValueError, match="Expected `name` to be a string"):
2020
Field(name=123, data=data["data_g"], grid=grid)
2121

@@ -32,7 +32,7 @@ def test_field_init_param_types():
3232
@pytest.mark.parametrize(
3333
"data,grid",
3434
[
35-
pytest.param(ux.UxDataArray(), XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])), id="uxdata-grid"),
35+
pytest.param(ux.UxDataArray(), XGrid.from_dataset(datasets_structured["ds_2d_left"]), id="uxdata-grid"),
3636
pytest.param(
3737
xr.DataArray(),
3838
UxGrid(
@@ -57,7 +57,7 @@ def test_field_incompatible_combination(data, grid):
5757
[
5858
pytest.param(
5959
datasets_structured["ds_2d_left"]["data_g"],
60-
XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])),
60+
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
6161
id="ds_2d_left",
6262
), # TODO: Perhaps this test should be expanded to cover more datasets?
6363
],
@@ -80,10 +80,10 @@ def test_field_init_fail_on_float_time_dim():
8080
(users are expected to use timedelta64 or datetime).
8181
"""
8282
ds = datasets_structured["ds_2d_left"].copy()
83-
ds["time"] = np.arange(0, T_structured, dtype="float64")
83+
ds["time"] = (ds["time"].dims, np.arange(0, T_structured, dtype="float64"), ds["time"].attrs)
8484

8585
data = ds["data_g"]
86-
grid = XGrid(xgcm.Grid(ds))
86+
grid = XGrid.from_dataset(ds)
8787
with pytest.raises(
8888
ValueError,
8989
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?",
@@ -100,7 +100,7 @@ def test_field_init_fail_on_float_time_dim():
100100
[
101101
pytest.param(
102102
datasets_structured["ds_2d_left"]["data_g"],
103-
XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])),
103+
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
104104
id="ds_2d_left",
105105
),
106106
],
@@ -119,7 +119,7 @@ def test_vectorfield_init_different_time_intervals():
119119

120120
def test_field_invalid_interpolator():
121121
ds = datasets_structured["ds_2d_left"]
122-
grid = XGrid(xgcm.Grid(ds))
122+
grid = XGrid.from_dataset(ds)
123123

124124
def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, invalid):
125125
return 0.0
@@ -131,7 +131,7 @@ def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, inval
131131

132132
def test_vectorfield_invalid_interpolator():
133133
ds = datasets_structured["ds_2d_left"]
134-
grid = XGrid(xgcm.Grid(ds))
134+
grid = XGrid.from_dataset(ds)
135135

136136
def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, invalid):
137137
return 0.0

0 commit comments

Comments
 (0)