Skip to content

Commit 0003519

Browse files
Fixing support for 1D depth grids (without lon or lat)
1 parent c5cb140 commit 0003519

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

src/parcels/_core/field.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
217217
_ei = None
218218
else:
219219
_ei = particles.ei[:, self.igrid]
220+
z = np.atleast_1d(z)
221+
y = np.atleast_1d(y)
222+
x = np.atleast_1d(x)
220223

221224
particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei)
222225

@@ -300,6 +303,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
300303
_ei = None
301304
else:
302305
_ei = particles.ei[:, self.igrid]
306+
z = np.atleast_1d(z)
307+
y = np.atleast_1d(y)
308+
x = np.atleast_1d(x)
303309

304310
particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei)
305311

src/parcels/_core/xgrid.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -289,22 +289,13 @@ def search(self, z, y, x, ei=None):
289289
else:
290290
zi, zeta = np.zeros(z.shape, dtype=int), np.zeros(z.shape, dtype=float)
291291

292-
if ds.lon.ndim == 1:
293-
yi, eta = _search_1d_array(ds.lat.values, y)
294-
xi, xsi = _search_1d_array(ds.lon.values, x)
295-
return {
296-
"Z": {"index": zi, "bcoord": zeta},
297-
"Y": {"index": yi, "bcoord": eta},
298-
"X": {"index": xi, "bcoord": xsi},
299-
}
292+
if "X" in self.axes and "Y" in self.axes and ds.lon.ndim == 2:
293+
yi, xi = None, None
294+
if ei is not None:
295+
axis_indices = self.unravel_index(ei)
296+
xi = axis_indices.get("X")
297+
yi = axis_indices.get("Y")
300298

301-
yi, xi = None, None
302-
if ei is not None:
303-
axis_indices = self.unravel_index(ei)
304-
xi = axis_indices.get("X")
305-
yi = axis_indices.get("Y")
306-
307-
if ds.lon.ndim == 2:
308299
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)
309300

310301
return {
@@ -313,7 +304,24 @@ def search(self, z, y, x, ei=None):
313304
"X": {"index": xi, "bcoord": xsi},
314305
}
315306

316-
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
307+
if "X" in self.axes and ds.lon.ndim > 2:
308+
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
309+
310+
if "Y" in self.axes:
311+
yi, eta = _search_1d_array(ds.lat.values, y)
312+
else:
313+
yi, eta = np.zeros(y.shape, dtype=int), np.zeros(y.shape, dtype=float)
314+
315+
if "X" in self.axes:
316+
xi, xsi = _search_1d_array(ds.lon.values, x)
317+
else:
318+
xi, xsi = np.zeros(x.shape, dtype=int), np.zeros(x.shape, dtype=float)
319+
320+
return {
321+
"Z": {"index": zi, "bcoord": zeta},
322+
"Y": {"index": yi, "bcoord": eta},
323+
"X": {"index": xi, "bcoord": xsi},
324+
}
317325

318326
@cached_property
319327
def _fpoint_info(self):

src/parcels/interpolators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,11 @@ def _get_corner_data_Agrid(
8383
xi = np.tile(np.array([xi, xi_1]).flatten(), lenT * lenZ * 2)
8484

8585
# Create DataArrays for indexing
86-
selection_dict = {
87-
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
88-
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
89-
}
86+
selection_dict = {}
87+
if "X" in axis_dim:
88+
selection_dict[axis_dim["X"]] = xr.DataArray(xi, dims=("points"))
89+
if "Y" in axis_dim:
90+
selection_dict[axis_dim["Y"]] = xr.DataArray(yi, dims=("points"))
9091
if "Z" in axis_dim:
9192
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
9293
if "time" in data.dims:

tests/test_xgrid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ def test_invalid_depth():
139139

140140
def test_vertical1D_field():
141141
nz = 11
142-
ds = xr.Dataset({"z1d": (["depth"], np.linspace(0, 10, nz))}, coords={"depth": np.linspace(0, 1, nz)})
142+
ds = xr.Dataset(
143+
{"z1d": (["depth"], np.linspace(0, 10, nz))},
144+
coords={"depth": (["depth"], np.linspace(0, 1, nz), {"axis": "Z"})},
145+
)
143146
grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS))
144147
field = Field("z1d", ds["z1d"], grid)
145148

0 commit comments

Comments
 (0)