Skip to content

Commit 9bb0302

Browse files
authored
Always force dask arrays to float in missing.interp_func (#4771)
* scipy.interpolate.interp1d always forces to float. * Copy type-check from scipy.interpolate.interp1d * Update missing.py * Test that pre- and post-compute dtypes matches * Update test_missing.py
1 parent 81ed507 commit 9bb0302

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

xarray/core/missing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,13 @@ def interp_func(var, x, new_x, method, kwargs):
730730
# if usefull, re-use localize for each chunk of new_x
731731
localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
732732

733+
# scipy.interpolate.interp1d always forces to float.
734+
# Use the same check for blockwise as well:
735+
if not issubclass(var.dtype.type, np.inexact):
736+
dtype = np.float_
737+
else:
738+
dtype = var.dtype
739+
733740
return da.blockwise(
734741
_dask_aware_interpnd,
735742
out_ind,
@@ -738,7 +745,7 @@ def interp_func(var, x, new_x, method, kwargs):
738745
interp_kwargs=kwargs,
739746
localize=localize,
740747
concatenate=True,
741-
dtype=var.dtype,
748+
dtype=dtype,
742749
new_axes=new_axes,
743750
)
744751

xarray/tests/test_missing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,20 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim():
370370
da.interpolate_na("time")
371371

372372

373+
@requires_dask
374+
@requires_scipy
375+
@pytest.mark.parametrize("dtype, method", [(int, "linear"), (int, "nearest")])
376+
def test_interpolate_dask_expected_dtype(dtype, method):
377+
da = xr.DataArray(
378+
data=np.array([0, 1], dtype=dtype),
379+
dims=["time"],
380+
coords=dict(time=np.array([0, 1])),
381+
).chunk(dict(time=2))
382+
da = da.interp(time=np.array([0, 0.5, 1, 2]), method=method)
383+
384+
assert da.dtype == da.compute().dtype
385+
386+
373387
@requires_bottleneck
374388
def test_ffill():
375389
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")

0 commit comments

Comments
 (0)