
Commit d6fae2e

Authored by Illviljan, with co-authors pre-commit-ci[bot] and dcherian
interp - Prefer broadcast over reindex when possible (#10554)
* Prefer broadcast over reindex when possible
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update dataset.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update dataset.py
* Update dataset.py
* Update interp.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update test_interp.py
* Update whats-new.rst
* add test for copy vs view
* deep copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Apply suggestions from code review (Co-authored-by: Deepak Cherian <[email protected]>)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
1 parent 1e145e9 commit d6fae2e

File tree

4 files changed: +61 / -11 lines


asv_bench/benchmarks/interp.py

Lines changed: 18 additions & 4 deletions
@@ -25,23 +25,37 @@ def setup(self, *args, **kwargs):
                 "var1": (("x", "y"), randn_xy),
                 "var2": (("x", "t"), randn_xt),
                 "var3": (("t",), randn_t),
+                "var4": (("z",), np.array(["text"])),
+                "var5": (("k",), np.array(["a", "b", "c"])),
             },
             coords={
                 "x": np.arange(nx),
                 "y": np.linspace(0, 1, ny),
                 "t": pd.date_range("1970-01-01", periods=nt, freq="D"),
                 "x_coords": ("x", np.linspace(1.1, 2.1, nx)),
+                "z": np.array([1]),
+                "k": np.linspace(0, nx, 3),
             },
         )

     @parameterized(["method", "is_short"], (["linear", "cubic"], [True, False]))
-    def time_interpolation(self, method, is_short):
+    def time_interpolation_numeric_1d(self, method, is_short):
         new_x = new_x_short if is_short else new_x_long
-        self.ds.interp(x=new_x, method=method).load()
+        self.ds.interp(x=new_x, method=method).compute()

     @parameterized(["method"], (["linear", "nearest"]))
-    def time_interpolation_2d(self, method):
-        self.ds.interp(x=new_x_long, y=new_y_long, method=method).load()
+    def time_interpolation_numeric_2d(self, method):
+        self.ds.interp(x=new_x_long, y=new_y_long, method=method).compute()
+
+    @parameterized(["is_short"], ([True, False]))
+    def time_interpolation_string_scalar(self, is_short):
+        new_z = new_x_short if is_short else new_x_long
+        self.ds.interp(z=new_z).compute()
+
+    @parameterized(["is_short"], ([True, False]))
+    def time_interpolation_string_1d(self, is_short):
+        new_k = new_x_short if is_short else new_x_long
+        self.ds.interp(k=new_k).compute()


 class InterpolationDask(Interpolation):
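Note: the two new string benchmarks differ only in the size of the dimension being interpolated — "z" has size 1 and exercises the new broadcast shortcut, while "k" has size 3 and still goes through reindexing. A rough standalone sketch of what they measure, outside the asv harness, with made-up values standing in for nx and new_x_long, and assuming SciPy is installed as in the benchmark environment:

import numpy as np
import xarray as xr

# Made-up stand-ins for the benchmark module's nx / new_x_long.
nx = 100
new_x_long = np.linspace(0.0, 0.9 * nx, 1_000)

ds = xr.Dataset(
    {
        # size-1 string dimension -> broadcast shortcut
        "var4": (("z",), np.array(["text"])),
        # size-3 string dimension -> reindex fallback
        "var5": (("k",), np.array(["a", "b", "c"])),
    },
    coords={"z": np.array([1]), "k": np.linspace(0, nx, 3)},
)

# Roughly what time_interpolation_string_scalar / time_interpolation_string_1d time:
ds.interp(z=new_x_long).compute()
ds.interp(k=new_x_long).compute()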

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,11 @@ Internal Changes
 ~~~~~~~~~~~~~~~~


+Performance
+~~~~~~~~~~~
+- Speed up non-numeric scalars when calling :py:meth:`Dataset.interp`. (:issue:`10054`, :pull:`10554`)
+  By `Jimmy Westling <https://github.com/illviljan>`_.
+
 .. _whats-new.2025.07.1:

 v2025.07.1 (July 09, 2025)

xarray/core/dataset.py

Lines changed: 16 additions & 7 deletions
@@ -3851,13 +3851,22 @@ def _validate_interp_indexer(x, new_x):
                 var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims}
                 variables[name] = missing.interp(var, var_indexers, method, **kwargs)
             elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims):
-                # For types that we do not understand do stepwise
-                # interpolation to avoid modifying the elements.
-                # reindex the variable instead because it supports
-                # booleans and objects and retains the dtype but inside
-                # this loop there might be some duplicate code that slows it
-                # down, therefore collect these signals and run it later:
-                reindex_vars.append(name)
+                if all(var.sizes[d] == 1 for d in (use_indexers.keys() & var.dims)):
+                    # Broadcastable, can be handled quickly without reindex:
+                    to_broadcast = (var.squeeze(),) + tuple(
+                        dest for _, dest in use_indexers.values()
+                    )
+                    variables[name] = broadcast_variables(*to_broadcast)[0].copy(
+                        deep=True
+                    )
+                else:
+                    # For types that we do not understand do stepwise
+                    # interpolation to avoid modifying the elements.
+                    # reindex the variable instead because it supports
+                    # booleans and objects and retains the dtype but inside
+                    # this loop there might be some duplicate code that slows it
+                    # down, therefore collect these signals and run it later:
+                    reindex_vars.append(name)
             elif all(d not in indexers for d in var.dims):
                 # For anything else we can only keep variables if they
                 # are not dependent on any coords that are being
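The new branch applies when every interpolated dimension of the non-numeric variable has size 1: squeezing it leaves a scalar, and broadcasting that scalar against the destination coordinates gives the same result a nearest-neighbour reindex would, without the reindex machinery; the deep copy ensures no read-only broadcast view leaks out. A minimal sketch of that equivalence using only public API (the actual code works on Variable objects via the internal broadcast_variables helper):

import numpy as np
import xarray as xr

# A string variable whose only interpolated dimension has size 1.
ds = xr.Dataset({"v": ("z", np.array(["text"]))}, coords={"z": [1]})
new_z = np.linspace(0, 10, 5)

# Previous behaviour: stepwise "interpolation" via nearest-neighbour reindexing.
via_reindex = ds.reindex(z=new_z, method="nearest")["v"]

# New fast path for the size-1 case: squeeze to a scalar, broadcast it
# against the destination coordinate, and copy so no view escapes.
scalar = ds["v"].squeeze(drop=True)
target = xr.DataArray(new_z, dims="z", coords={"z": new_z})
via_broadcast = xr.broadcast(scalar, target)[0].copy(deep=True)

np.testing.assert_array_equal(via_reindex.values, via_broadcast.values)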

xarray/tests/test_interp.py

Lines changed: 22 additions & 0 deletions
@@ -1065,6 +1065,28 @@ def test_interp1d_complex_out_of_bounds() -> None:
     assert_identical(actual, expected)


+@requires_scipy
+def test_interp_non_numeric_scalar() -> None:
+    ds = xr.Dataset(
+        {
+            "non_numeric": ("time", np.array(["a"])),
+        },
+        coords={"time": (np.array([0]))},
+    )
+    actual = ds.interp(time=np.linspace(0, 3, 3))
+
+    expected = xr.Dataset(
+        {
+            "non_numeric": ("time", np.array(["a", "a", "a"])),
+        },
+        coords={"time": np.linspace(0, 3, 3)},
+    )
+    xr.testing.assert_identical(actual, expected)
+
+    # Make sure the array is a copy:
+    assert actual["non_numeric"].data.base is None
+
+
 @requires_scipy
 def test_interp_non_numeric_1d() -> None:
     ds = xr.Dataset(
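The final assertion relies on NumPy's .base attribute: a broadcast result is normally a read-only view whose .base points at the source buffer, so base is None confirms interp returned an independent array (hence the deep copy in dataset.py above). A small illustration of the distinction the test checks:

import numpy as np

src = np.array(["a"])

view = np.broadcast_to(src, (3,))  # broadcasting yields a read-only view
copy = view.copy()                 # copying materialises an independent array

assert view.base is not None       # the view borrows src's memory
assert copy.base is None           # the copy owns its memory
assert not view.flags.writeable
assert copy.flags.writeable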
