Skip to content

Commit 3afd9dd

Browse files
committed
Generalise ravel/unravel to any grid
1 parent 60408b7 commit 3afd9dd

File tree

5 files changed

+126
-52
lines changed

5 files changed

+126
-52
lines changed

parcels/basegrid.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from enum import IntEnum
55
from typing import TYPE_CHECKING
66

7+
import numpy as np
8+
79
if TYPE_CHECKING:
810
import numpy as np
911

@@ -67,7 +69,6 @@ def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int,
6769
"""
6870
...
6971

70-
@abstractmethod
7172
def ravel_index(self, axis_indices: dict[str, int]) -> int:
7273
"""
7374
Convert a dictionary of axis indices to a single encoded index (ei).
@@ -101,9 +102,10 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int:
101102
NotImplementedError
102103
Raised if the method is not implemented for the current grid type.
103104
"""
104-
...
105+
dims = np.array([self.get_axis_dim(axis) for axis in self.axes], dtype=int)
106+
indices = np.array([axis_indices[axis] for axis in self.axes], dtype=int)
107+
return _ravel(dims, indices)
105108

106-
@abstractmethod
107109
def unravel_index(self, ei: int) -> dict[str, int]:
108110
"""
109111
Convert a single encoded index (ei) back to a dictionary of axis indices.
@@ -134,14 +136,19 @@ def unravel_index(self, ei: int) -> dict[str, int]:
134136
NotImplementedError
135137
Raised if the method is not implemented for the current grid type.
136138
"""
137-
...
139+
dims = np.array([self.get_axis_dim(axis) for axis in self.axes], dtype=int)
140+
indices = _unravel(dims, ei)
141+
return dict(zip(self.axes, indices, strict=True))
138142

139143
@property
140144
@abstractmethod
141145
def axes(self) -> list[str]:
142146
"""
143147
Return a list of axis names that are part of this grid.
144148
149+
This list must at least be of length 1, and `get_axis_dim` should
150+
return a valid integer for each axis name in the list.
151+
145152
Returns
146153
-------
147154
list[str]
@@ -170,3 +177,65 @@ def get_axis_dim(self, axis: str) -> int:
170177
If the specified axis is not part of this grid.
171178
"""
172179
...
180+
181+
182+
def _unravel(dims, ei):
183+
"""
184+
Converts a flattened (raveled) index back to multi-dimensional indices.
185+
186+
Args:
187+
dims (1d-array-like): The dimensions along each axis
188+
ei (int): The flattened index to convert
189+
190+
Returns
191+
-------
192+
array-like: Indices along each axis corresponding to the given flattened index
193+
194+
Example:
195+
>>> dims = [2, 3, 4]
196+
>>> ei = 9
197+
>>> unravel(dims, ei)
198+
array([0, 2, 1])
199+
# Calculation:
200+
# i0 = 9 // (3*4) = 9 // 12 = 0
201+
# remainder = 9 % 12 = 9
202+
# i1 = 9 // 4 = 2
203+
# i2 = 9 % 4 = 1
204+
"""
205+
strides = np.cumprod(dims[::-1])[::-1]
206+
207+
indices = np.empty(len(dims), dtype=int)
208+
209+
for i in range(len(dims) - 1):
210+
indices[i] = ei // strides[i + 1]
211+
ei = ei % strides[i + 1]
212+
213+
indices[-1] = ei
214+
return indices
215+
216+
217+
def _ravel(dims, indices):
218+
"""
219+
Converts indices to a flattened (raveled) index.
220+
221+
Args:
222+
dims (1d-array-like): The dimensions along each axis
223+
indices (array-like): Indices along each axis to convert
224+
225+
Returns
226+
-------
227+
int: The flattened index corresponding to the given indices
228+
229+
Example:
230+
>>> dims = [2, 3, 4]
231+
>>> indices = [0, 2, 1]
232+
>>> ravel(dims, indices)
233+
9
234+
# Calculation: 0 * (3 * 4) + 2 * (4) + 1 = 0 + 8 + 1 = 9
235+
"""
236+
strides = np.cumprod(dims[::-1])[::-1]
237+
ei = 0
238+
for i in range(len(dims) - 1):
239+
ei += indices[i] * strides[i + 1]
240+
241+
return ei + indices[-1]

parcels/uxgrid.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,3 @@ def _get_barycentric_coordinates(self, y, x, fi):
114114
bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
115115
err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(np.dot(bcoord, nodes[:, 1]) - coord[1])
116116
return bcoord, err
117-
118-
def ravel_index(self, axis_indices: dict[_UXGRID_AXES, int]):
119-
return axis_indices["FACE"] + self.uxgrid.n_face * axis_indices["Z"]
120-
121-
def unravel_index(self, ei) -> dict[_UXGRID_AXES, int]:
122-
zi = ei // self.uxgrid.n_face
123-
fi = ei % self.uxgrid.n_face
124-
return {"Z": zi, "FACE": fi}

parcels/xgrid.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,22 +187,6 @@ def search(self, z, y, x, ei=None):
187187

188188
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
189189

190-
def ravel_index(self, axis_indices: dict[_XGRID_AXES, int]) -> int:
191-
xi = axis_indices.get("X", 0)
192-
yi = axis_indices.get("Y", 0)
193-
zi = axis_indices.get("Z", 0)
194-
xdim = self.get_axis_dim("X")
195-
ydim = self.get_axis_dim("Y")
196-
return xi + xdim * yi + xdim * ydim * zi
197-
198-
def unravel_index(self, ei) -> dict[_XGRID_AXES, int]:
199-
zi = ei // (self.xdim * self.ydim)
200-
ei = ei % (self.xdim * self.ydim)
201-
202-
yi = ei // self.xdim
203-
xi = ei % self.xdim
204-
return {"Z": zi, "Y": yi, "X": xi}
205-
206190

207191
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
208192
"""For a given dimension name in a grid, returns the direction axis it is on."""

tests/v4/test_basegrid.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import itertools
4+
5+
import numpy as np
6+
import pytest
7+
8+
from parcels.basegrid import BaseGrid
9+
10+
11+
class TestGrid(BaseGrid):
12+
def __init__(self, axis_dim: dict[str, int]):
13+
self.axis_dim = axis_dim
14+
15+
def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int, float | np.ndarray]]:
16+
pass
17+
18+
@property
19+
def axes(self) -> list[str]:
20+
return list(self.axis_dim.keys())
21+
22+
def get_axis_dim(self, axis: str) -> int:
23+
return self.axis_dim[axis]
24+
25+
26+
@pytest.mark.parametrize(
27+
"grid",
28+
[
29+
TestGrid({"Z": 10, "Y": 20, "X": 30}),
30+
TestGrid({"Z": 5, "Y": 15}),
31+
TestGrid({"Z": 8}),
32+
TestGrid({"Z": 12, "FACE": 25}),
33+
],
34+
)
35+
def test_basegrid_ravel_unravel_index(grid):
36+
axes = grid.axes
37+
dimensionalities = (grid.get_axis_dim(axis) for axis in axes)
38+
all_possible_axis_indices = itertools.product(*[range(dim) for dim in dimensionalities])
39+
40+
encountered_eis = []
41+
42+
for axis_indices_numeric in all_possible_axis_indices:
43+
axis_indices = dict(zip(axes, axis_indices_numeric, strict=True))
44+
45+
ei = grid.ravel_index(axis_indices)
46+
axis_indices_test = grid.unravel_index(ei)
47+
assert axis_indices_test == axis_indices
48+
encountered_eis.append(ei)
49+
50+
encountered_eis = sorted(encountered_eis)
51+
assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique."
52+
assert np.allclose(np.diff(np.array(encountered_eis)), 1), "Raveled indices are not consecutive integers."
53+
assert encountered_eis[0] == 0, "Raveled indices do not start at 0."

tests/v4/test_xgrid.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,30 +90,6 @@ def test_invalid_lon_lat():
9090
XGrid(xgcm.Grid(ds, periodic=False))
9191

9292

93-
def test_xgrid_ravel_unravel_index():
94-
ds = datasets["ds_2d_left"]
95-
grid = XGrid(xgcm.Grid(ds, periodic=False))
96-
97-
xdim = grid.xdim
98-
ydim = grid.ydim
99-
zdim = grid.zdim
100-
101-
encountered_eis = []
102-
for xi in range(xdim):
103-
for yi in range(ydim):
104-
for zi in range(zdim):
105-
axis_indices = {"Z": zi, "Y": yi, "X": xi}
106-
ei = grid.ravel_index(axis_indices)
107-
axis_indices_test = grid.unravel_index(ei)
108-
assert axis_indices_test == axis_indices
109-
encountered_eis.append(ei)
110-
111-
encountered_eis = sorted(encountered_eis)
112-
assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique."
113-
assert np.allclose(np.diff(np.array(encountered_eis)), 1), "Raveled indices are not consecutive integers."
114-
assert encountered_eis[0] == 0, "Raveled indices do not start at 0."
115-
116-
11793
@pytest.mark.parametrize(
11894
"ds",
11995
[

0 commit comments

Comments
 (0)