Skip to content

Commit cde96c1

Browse files
committed
add vertical dim arg
1 parent 9625bbf commit cde96c1

File tree

1 file changed

+46
-29
lines changed

1 file changed

+46
-29
lines changed

src/earthkit/meteo/vertical/interpolation.py

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def interpolate_monotonic(
3232
coord: xr.DataArray,
3333
target_coord: TargetCoordinates,
3434
interpolation: Literal["linear", "log", "nearest"] = "linear",
35+
vertical_dim: str = "z",
3536
) -> xr.DataArray:
3637
"""Interpolate a field to isolevels of a monotonic target field.
3738
@@ -48,6 +49,8 @@ def interpolate_monotonic(
4849
target coordinate definition
4950
interpolation : str
5051
interpolation algorithm, one of {"linear", "log", "nearest"}
52+
vertical_dim : str
53+
name of the vertical dimension
5154
5255
Returns
5356
-------
@@ -60,8 +63,8 @@ def interpolate_monotonic(
6063
raise ValueError(f"Unknown interpolation: {interpolation}")
6164

6265
# ... determine direction of target field
63-
dtdz = coord.diff("z")
64-
positive = np.all(dtdz > 0)
66+
dtdz = coord.diff(vertical_dim)
67+
positive = np.all(dtdz > 0).item()
6568

6669
if not positive and not np.all(dtdz < 0):
6770
raise ValueError("target data is not monotonic in the vertical dimension")
@@ -75,8 +78,8 @@ def interpolate_monotonic(
7578

7679
# Interpolate
7780
# ... prepare interpolation
78-
tkm1 = coord.shift(z=1)
79-
fkm1 = data.shift(z=1)
81+
tkm1 = coord.shift({vertical_dim: 1})
82+
fkm1 = data.shift({vertical_dim: 1})
8083

8184
# ... loop through target values
8285
for target_idx, t0 in enumerate(target_coord.values):
@@ -87,19 +90,19 @@ def interpolate_monotonic(
8790
# ... note that if the condition above is not fulfilled, minind will
8891
# be set to k_top
8992
if positive:
90-
t2 = coord.where((coord < t0) & (tkm1 >= t0))
93+
t2 = coord.where((coord >= t0) & (tkm1 <= t0))
9194
else:
92-
t2 = coord.where((coord > t0) & (tkm1 <= t0))
95+
t2 = coord.where((coord <= t0) & (tkm1 >= t0))
9396

94-
minind = t2.fillna(np.inf).argmin(dim="z")
97+
minind = t2.fillna(np.inf).argmin(dim=vertical_dim)
9598

9699
# ... extract pressure and field at level k
97-
t2 = t2[{"z": minind}]
98-
f2 = data[{"z": minind}]
100+
t2 = t2[{vertical_dim: minind}]
101+
f2 = data[{vertical_dim: minind}]
99102
# ... extract pressure and field at level k-1
100103
# ... note that f1 and p1 are both undefined, if minind equals k_top
101-
f1 = fkm1[{"z": minind}]
102-
t1 = tkm1[{"z": minind}]
104+
f1 = fkm1[{vertical_dim: minind}]
105+
t1 = tkm1[{vertical_dim: minind}]
103106

104107
# ... compute the interpolation weights
105108
if interpolation == "linear":
@@ -120,7 +123,7 @@ def interpolate_monotonic(
120123
ratio = xr.where(np.abs(t0 - t1) >= np.abs(t0 - t2), 1.0, 0.0)
121124

122125
# ... interpolate and update field_on_target
123-
field_on_target[{"z": target_idx}] = (1.0 - ratio) * f1 + ratio * f2
126+
field_on_target[{vertical_dim: target_idx}] = (1.0 - ratio) * f1 + ratio * f2
124127

125128
return field_on_target
126129

@@ -131,6 +134,7 @@ def interpolate_to_pressure_levels(
131134
target_p: Sequence[float],
132135
target_p_units: Literal["Pa", "hPa"] = "Pa",
133136
interpolation: Literal["linear", "log", "nearest"] = "linear",
137+
vertical_dim: str = "z",
134138
) -> xr.DataArray:
135139
"""Interpolate a field from model (k) levels to pressure coordinates.
136140
@@ -149,6 +153,8 @@ def interpolate_to_pressure_levels(
149153
pressure target coordinate units
150154
interpolation : str
151155
interpolation algorithm, one of {"linear", "log", "nearest"}
156+
vertical_dim : str
157+
name of the vertical dimension
152158
153159
Returns
154160
-------
@@ -182,7 +188,7 @@ def interpolate_to_pressure_levels(
182188
values=target_values.tolist(),
183189
)
184190

185-
return interpolate_monotonic(data, p, target, interpolation)
191+
return interpolate_monotonic(data, p, target, interpolation, vertical_dim)
186192

187193

188194
def interpolate_sleve_to_coord_levels(
@@ -191,6 +197,7 @@ def interpolate_sleve_to_coord_levels(
191197
coord: xr.DataArray,
192198
target_coord: TargetCoordinates,
193199
folding_mode: Literal["low_fold", "high_fold", "undef_fold"] = "undef_fold",
200+
vertical_dim: str = "z",
194201
) -> xr.DataArray:
195202
"""Interpolate a field from sleve levels to coordinates w.r.t. an arbitrary field.
196203
@@ -210,6 +217,8 @@ def interpolate_sleve_to_coord_levels(
210217
folding_mode : str
211218
handle when the target is observed multiple times in a column,
212219
one of {"low_fold", "high_fold", "undef_fold"}
220+
vertical_dim : str
221+
name of the vertical dimension
213222
214223
Returns
215224
-------
@@ -227,46 +236,48 @@ def interpolate_sleve_to_coord_levels(
227236
h_max = 100000.0
228237

229238
# Prepare output field on target coordinates
230-
field_on_target = _init_field_with_vcoord(data.broadcast_like(coord), target_coord, np.nan)
239+
field_on_target = _init_field_with_vcoord(
240+
data.broadcast_like(coord), target_coord, np.nan, vertical_dim=vertical_dim
241+
)
231242

232243
# Interpolate
233244
# ... prepare interpolation
234-
tkm1 = coord.shift(z=1)
235-
fkm1 = data.shift(z=1)
245+
tkm1 = coord.shift({vertical_dim: 1})
246+
fkm1 = data.shift({vertical_dim: 1})
236247

237248
# ... loop through tc values
238249
for t_idx, t0 in enumerate(target_coord.values):
239-
folding_coord_exception = xr.full_like(h[{"z": 0}], False)
250+
folding_coord_exception = xr.full_like(h[{vertical_dim: 0}], False)
240251
# ... find the height field where target is >= t0 on level k and was <= t0
241252
# on level k-1 or where theta is <= th0 on level k
242253
# and was >= th0 on level k-1
243254
ht = h.where(((coord >= t0) & (tkm1 <= t0)) | ((coord <= t0) & (tkm1 >= t0)))
244255
if folding_mode == "undef_fold":
245256
# ... find condition where more than one interval is found, which
246257
# contains the target coordinate value
247-
tmp = xr.where(ht.notnull(), 1, 0).sum(dim=["z"])
258+
tmp = xr.where(ht.notnull(), 1, 0).sum(dim=[vertical_dim])
248259
folding_coord_exception = tmp.where(tmp > 1).notnull()
249260
if folding_mode in ("low_fold", "undef_fold"):
250261
# ... extract the index k of the smallest height at which
251262
# the condition is fulfilled
252-
tcind = ht.fillna(h_max).argmin(dim="z")
263+
tcind = ht.fillna(h_max).argmin(dim=vertical_dim)
253264
if folding_mode == "high_fold":
254265
# ... extract the index k of the largest height at which the condition
255266
# is fulfilled
256-
tcind = ht.fillna(h_min).argmax(dim="z")
267+
tcind = ht.fillna(h_min).argmax(dim=vertical_dim)
257268

258269
# ... extract theta and field at level k
259-
t2 = coord[{"z": tcind}]
260-
f2 = data[{"z": tcind}]
270+
t2 = coord[{vertical_dim: tcind}]
271+
f2 = data[{vertical_dim: tcind}]
261272
# ... extract theta and field at level k-1
262-
f1 = fkm1[{"z": tcind}]
263-
t1 = tkm1[{"z": tcind}]
273+
f1 = fkm1[{vertical_dim: tcind}]
274+
t1 = tkm1[{vertical_dim: tcind}]
264275

265276
# ... compute the interpolation weights
266277
ratio = xr.where(np.abs(t2 - t1) > 0, (t0 - t1) / (t2 - t1), 0.0)
267278

268279
# ... interpolate and update field on target
269-
field_on_target[{"z": t_idx}] = xr.where(
280+
field_on_target[{vertical_dim: t_idx}] = xr.where(
270281
folding_coord_exception, np.nan, (1.0 - ratio) * f1 + ratio * f2
271282
)
272283

@@ -280,6 +291,7 @@ def interpolate_sleve_to_theta_levels(
280291
target_theta: Sequence[float],
281292
target_t_units: Literal["K", "cK"] = "K",
282293
folding_mode: Literal["low_fold", "high_fold", "undef_fold"] = "undef_fold",
294+
vertical_dim: str = "z",
283295
) -> xr.DataArray:
284296
"""Interpolate a field from sleve levels to potential temperature coordinates.
285297
@@ -301,6 +313,8 @@ def interpolate_sleve_to_theta_levels(
301313
folding_mode : str
302314
handle when the target is observed multiple times in a column,
303315
one of {"low_fold", "high_fold", "undef_fold"}
316+
vertical_dim : str
317+
name of the vertical dimension
304318
305319
Returns
306320
-------
@@ -342,14 +356,15 @@ def interpolate_sleve_to_theta_levels(
342356
values=tc_values.tolist(),
343357
)
344358

345-
return interpolate_sleve_to_coord_levels(data, h, theta, tc, folding_mode)
359+
return interpolate_sleve_to_coord_levels(data, h, theta, tc, folding_mode, vertical_dim)
346360

347361

348362
def _init_field_with_vcoord(
349363
parent: xr.DataArray,
350364
vcoord: TargetCoordinates,
351365
fill_value: Any,
352366
dtype: np.dtype | None = None,
367+
vertical_dim: str = "z",
353368
) -> xr.DataArray:
354369
"""Initialize an xarray.DataArray with new vertical coordinates.
355370
@@ -367,6 +382,8 @@ def _init_field_with_vcoord(
367382
dtype : np.dtype, optional
368383
fill value data type; defaults to None (in this case
369384
the data type is inherited from the parent field)
385+
vertical_dim : str
386+
name of the vertical dimension
370387
371388
Returns
372389
-------
@@ -388,12 +405,12 @@ def _init_field_with_vcoord(
388405
# parent.metadata, typeOfLevel=vcoord.type_of_level
389406
# )
390407
# dims
391-
sizes = dict(parent.sizes.items()) | {"z": vcoord.size}
408+
sizes = dict(parent.sizes.items()) | {vertical_dim: vcoord.size}
392409
# coords
393410
# ... inherit all except for the vertical coordinates
394-
coords = {c: v for c, v in parent.coords.items() if c != "z"}
411+
coords = {c: v for c, v in parent.coords.items() if vertical_dim not in v.dims}
395412
# ... initialize the vertical target coordinates
396-
coords["z"] = xr.IndexVariable("z", vcoord.values)
413+
coords[vertical_dim] = xr.IndexVariable(vertical_dim, vcoord.values)
397414
# dtype
398415
if dtype is None:
399416
dtype = parent.data.dtype

0 commit comments

Comments
 (0)