Skip to content

Commit a0c7d0e

Browse files
committed
pass kwargs down to the reducer
1 parent 11f5c7d commit a0c7d0e

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

src/xarray_multiscale/multiscale.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def multiscale(
3636
chunks: ChunkOption | Sequence[int] | dict[Hashable, int] = "preserve",
3737
chained: bool = True,
3838
namer: Callable[[int], str] = _default_namer,
39+
**kwargs: Any,
3940
) -> list[xarray.DataArray]:
4041
"""
4142
Generate a coordinate-aware multiscale representation of an array.
@@ -90,6 +91,9 @@ def multiscale(
9091
index and return a string. The default function simply prepends the string
9192
representation of the integer with the character "s".
9293
94+
**kwargs: Any
95+
Additional keyword arguments that will be passed to the reduction function.
96+
9397
Returns
9498
-------
9599
result : list[xarray.DataArray]
@@ -127,7 +131,7 @@ def multiscale(
127131
else:
128132
scale = tuple(s**level for s in scale_factors)
129133
source = result[0]
130-
downscaled = downscale(source, reduction, scale, preserve_dtype)
134+
downscaled = downscale(source, reduction, scale, preserve_dtype, **kwargs)
131135
downscaled.name = namer(level)
132136
result.append(downscaled)
133137

@@ -201,7 +205,7 @@ def downscale(
201205
if to_downscale.chunks is not None:
202206
downscaled_data = downscale_dask(to_downscale.data, reduction, scale_factors, **kwargs)
203207
else:
204-
downscaled_data = reduction(to_downscale.data, scale_factors)
208+
downscaled_data = reduction(to_downscale.data, scale_factors, **kwargs)
205209
if preserve_dtype:
206210
downscaled_data = downscaled_data.astype(array.dtype)
207211
downscaled_coords = downscale_coords(to_downscale, scale_factors)

tests/test_multiscale.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dask.array as da
22
import numpy as np
33
import pytest
4+
from src.xarray_multiscale.reducers import windowed_rank
45
from xarray import DataArray
56
from xarray.testing import assert_equal
67

@@ -127,6 +128,14 @@ def test_multiscale(ndim: int, chained: bool):
127128
assert np.array_equal(pyr[0].data, base_array)
128129

129130

131+
@pytest.mark.parametrize("rank", (-1, 0, 1))
132+
def test_multiscale_rank_kwargs(rank: int):
133+
data = np.arange(16)
134+
window_size = (4,)
135+
pyr = multiscale(data, windowed_rank, window_size, rank=rank)
136+
assert np.array_equal(pyr[1].data, windowed_rank(data, window_size=window_size, rank=rank))
137+
138+
130139
def test_chunking():
131140
ndim = 3
132141
shape = (16,) * ndim

tests/test_reducers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_windowed_rank():
6969
window_size = (2, 2, 2)
7070

7171
# 2nd brightest voxel
72-
rank = np.product(window_size) - 2
72+
rank = np.prod(window_size) - 2
7373
answer = np.array([[[7, 7], [7, 7]], [[7, 7], [7, 7]]])
7474
results = windowed_rank(larger_array, window_size, rank)
7575
assert np.array_equal(results, answer)

0 commit comments

Comments
 (0)