1- from collections .abc import Hashable , Mapping
1+ from collections .abc import Hashable , Mapping , Sequence
22from functools import cached_property
33from typing import Literal , cast
44
1010from parcels ._index_search import _search_indices_curvilinear_2d
1111from parcels .basegrid import BaseGrid
1212
13- _XGRID_AXES_ORDERING = "ZYX"
1413_XGRID_AXES = Literal ["X" , "Y" , "Z" ]
14+ _XGRID_AXES_ORDERING : Sequence [_XGRID_AXES ] = "ZYX"
1515
1616_XGCM_AXIS_DIRECTION = Literal ["X" , "Y" , "Z" , "T" ]
1717_XGCM_AXIS_POSITION = Literal ["center" , "left" , "right" , "inner" , "outer" ]
1818_XGCM_AXES = Mapping [_XGCM_AXIS_DIRECTION , xgcm .Axis ]
1919
20+ _FIELD_DATA_ORDERING : Sequence [_XGCM_AXIS_DIRECTION ] = "TZYX"
21+
22+ _DEFAULT_XGCM_KWARGS = {"periodic" : False }
23+
2024
2125def get_cell_count_along_dim (axis : xgcm .Axis ) -> int :
2226 first_coord = list (axis .coords .items ())[0 ]
@@ -34,6 +38,48 @@ def _get_xgrid_axes(grid: xgcm.Grid) -> list[_XGRID_AXES]:
3438 return sorted (spatial_axes , key = _XGRID_AXES_ORDERING .index )
3539
3640
41+ def _drop_field_data (ds : xr .Dataset ) -> xr .Dataset :
42+ """
43+ Removes DataArrays from the dataset that are associated with field data so that
44+ when passed to the XGCM grid, the object only functions as an in memory representation
45+ of the grid.
46+ """
47+ return ds .drop_vars (ds .data_vars )
48+
49+
50+ def _transpose_xfield_data_to_tzyx (da : xr .DataArray , xgcm_grid : xgcm .Grid ) -> xr .DataArray :
51+ """
52+ Transpose a DataArray of any shape into a 4D array of order TZYX. Uses xgcm to determine
53+ the axes, and inserts mock dimensions of size 1 for any axes not present in the DataArray.
54+ """
55+ ax_dims = [(get_axis_from_dim_name (xgcm_grid .axes , dim ), dim ) for dim in da .dims ]
56+
57+ if all (ax_dim [0 ] is None for ax_dim in ax_dims ):
58+ # Assuming its a 1D constant field (hence has no axes)
59+ assert da .shape == (1 , 1 , 1 , 1 )
60+ return da .rename ({old_dim : f"mock{ axis } " for old_dim , axis in zip (da .dims , _FIELD_DATA_ORDERING , strict = True )})
61+
62+ # All dimensions must be associated with an axis in the grid
63+ if any (ax_dim [0 ] is None for ax_dim in ax_dims ):
64+ raise ValueError (
65+ f"DataArray { da .name !r} with dims { da .dims } has dimensions that are not associated with a direction on the provided grid."
66+ )
67+
68+ axes_not_in_field = set (_FIELD_DATA_ORDERING ) - set (ax_dim [0 ] for ax_dim in ax_dims )
69+
70+ mock_dims_to_create = {}
71+ for ax in axes_not_in_field :
72+ mock_dims_to_create [f"mock{ ax } " ] = 1
73+ ax_dims .append ((ax , f"mock{ ax } " ))
74+
75+ if mock_dims_to_create :
76+ da = da .expand_dims (mock_dims_to_create , create_index_for_new_dim = False )
77+
78+ ax_dims = sorted (ax_dims , key = lambda x : _FIELD_DATA_ORDERING .index (x [0 ]))
79+
80+ return da .transpose (* [ax_dim [1 ] for ax_dim in ax_dims ])
81+
82+
3783class XGrid (BaseGrid ):
3884 """
3985 Class to represent a structured grid in Parcels. Wraps a xgcm-like Grid object (we use a trimmed down version of the xgcm.Grid class that is vendored with Parcels).
@@ -53,6 +99,18 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
5399 if len (set (grid .axes ) & {"X" , "Y" , "Z" }) > 0 : # Only if spatial grid is >0D (see #2054 for further development)
54100 assert_valid_lat_lon (ds ["lat" ], ds ["lon" ], grid .axes )
55101
102+ @classmethod
103+ def from_dataset (cls , ds : xr .Dataset , mesh = "flat" , xgcm_kwargs = None ):
104+ """WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
105+ if xgcm_kwargs is None :
106+ xgcm_kwargs = {}
107+
108+ xgcm_kwargs = {** _DEFAULT_XGCM_KWARGS , ** xgcm_kwargs }
109+
110+ ds = _drop_field_data (ds )
111+ grid = xgcm .Grid (ds , ** xgcm_kwargs )
112+ return cls (grid , mesh = mesh )
113+
56114 @property
57115 def axes (self ) -> list [_XGRID_AXES ]:
58116 return _get_xgrid_axes (self .xgcm_grid )
0 commit comments