Skip to content

Commit b7bdddc

Browse files
committed
Update XGrid axis dim to use cell count
1 parent e1f028c commit b7bdddc

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

parcels/basegrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def axes(self) -> list[str]:
159159
@abstractmethod
160160
def get_axis_dim(self, axis: str) -> int:
161161
"""
162-
Return the dimensionality (number of cells/edges) along a specific axis.
162+
Return the dimensionality (number of cells/faces) along a specific axis.
163163
164164
Parameters
165165
----------

parcels/xgrid.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis]
1919

2020

21-
def get_cell_edge_count_along_dim(axis: xgcm.Axis | None) -> int:
22-
if axis is None:
23-
return 1
21+
def get_cell_count_along_dim(axis: xgcm.Axis) -> int:
2422
first_coord = list(axis.coords.items())[0]
2523
_, coord_var = first_coord
2624

27-
return axis._ds[coord_var].size
25+
return axis._ds[coord_var].size - 1
2826

2927

3028
def get_time(axis: xgcm.Axis) -> npt.NDArray:
@@ -129,7 +127,7 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int:
129127
if axis not in self.axes:
130128
raise ValueError(f"Axis {axis!r} is not part of this grid. Available axes: {self.axes}")
131129

132-
return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get(axis))
130+
return get_cell_count_along_dim(self.xgcm_grid.axes[axis])
133131

134132
@property
135133
def _z4d(self) -> Literal[0, 1]:

tests/v4/test_xgrid.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
GridTestCase(datasets["ds_2d_left"], "lat", datasets["ds_2d_left"].YG.values),
1717
GridTestCase(datasets["ds_2d_left"], "depth", datasets["ds_2d_left"].ZG.values),
1818
GridTestCase(datasets["ds_2d_left"], "time", datasets["ds_2d_left"].time.values.astype(np.float64) / 1e9),
19-
GridTestCase(datasets["ds_2d_left"], "xdim", X),
20-
GridTestCase(datasets["ds_2d_left"], "ydim", Y),
21-
GridTestCase(datasets["ds_2d_left"], "zdim", Z),
19+
GridTestCase(datasets["ds_2d_left"], "xdim", X - 1),
20+
GridTestCase(datasets["ds_2d_left"], "ydim", Y - 1),
21+
GridTestCase(datasets["ds_2d_left"], "zdim", Z - 1),
2222
]
2323

2424

@@ -53,9 +53,9 @@ def test_xgrid_axes(ds):
5353
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
5454
def test_xgrid_get_axis_dim(ds):
5555
grid = XGrid(xgcm.Grid(ds, periodic=False))
56-
assert grid.get_axis_dim("Z") == Z
57-
assert grid.get_axis_dim("Y") == Y
58-
assert grid.get_axis_dim("X") == X
56+
assert grid.get_axis_dim("Z") == Z - 1
57+
assert grid.get_axis_dim("Y") == Y - 1
58+
assert grid.get_axis_dim("X") == X - 1
5959

6060

6161
def test_invalid_xgrid_field_array():

0 commit comments

Comments
 (0)