Skip to content

Commit 63ddaa7

Browse files
committed
add recursive mode
1 parent 6b926c4 commit 63ddaa7

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed

src/xarray_multiscale/multiscale.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,26 @@ def multiscale(
6464
tuple(s ** l for s in scale_factors) for l in levels
6565
)
6666
result = [_ingest_array(array, scales[0])]
67-
base_attrs = result[0].attrs
68-
base_coords = result[0].coords
6967

70-
for scale in scales[1:]:
71-
downscaled = downscale(
72-
result[0], reduction, scale, pad_mode=pad_mode, preserve_dtype=preserve_dtype
73-
)
68+
for level in levels[1:]:
69+
if recursive:
70+
scale = scale_factors
71+
downscaled = downscale(result[-1], reduction, scale, pad_mode=pad_mode)
72+
else:
73+
scale = scales[level]
74+
downscaled = downscale(result[0], reduction, scale, pad_mode=pad_mode)
7475
result.append(downscaled)
7576

77+
if preserve_dtype:
78+
result = [r.astype(array.dtype) for r in result]
79+
7680
if chunks is not None:
7781
if isinstance(chunks, Sequence):
7882
_chunks = {k: v for k, v in zip(result[0].dims, chunks)}
7983
else:
8084
_chunks = chunks
8185
result = [r.chunk(_chunks) for r in result]
86+
8287
return result
8388

8489

@@ -202,7 +207,6 @@ def downscale(
202207
reduction: Callable,
203208
scale_factors: Sequence[int],
204209
pad_mode: Optional[str] = None,
205-
preserve_dtype: bool = True,
206210
**kwargs,
207211
) -> DataArray:
208212
"""
@@ -239,8 +243,6 @@ def downscale(
239243
**kwargs,
240244
)
241245

242-
if preserve_dtype:
243-
coarsened = coarsened.astype(array.dtype)
244246

245247
if isinstance(array, xarray.DataArray):
246248
base_coords = array.coords

tests/test_multiscale.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,20 @@ def test_downscale_2d():
5454
arr_xarray = DataArray(arr_dask)
5555

5656
downscaled_numpy_float = downscale(
57-
arr_numpy, np.mean, scale, preserve_dtype=False
58-
).compute()
57+
arr_numpy, np.mean, scale).compute()
5958

6059
downscaled_dask_float = downscale(
61-
arr_dask, np.mean, scale, preserve_dtype=False
62-
).compute()
60+
arr_dask, np.mean, scale).compute()
6361

6462
downscaled_xarray_float = downscale(
65-
arr_xarray, np.mean, scale, preserve_dtype=False
66-
).compute()
63+
arr_xarray, np.mean, scale).compute()
6764

6865
answer_float = np.array([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]])
66+
6967
assert np.array_equal(downscaled_numpy_float, answer_float)
7068
assert np.array_equal(downscaled_dask_float, answer_float)
7169
assert np.array_equal(downscaled_xarray_float, answer_float)
7270

73-
downscaled_numpy_int = downscale(
74-
arr_numpy, np.mean, scale, dtype=arr_numpy.dtype
75-
).compute()
76-
downscaled_dask_int = downscale(
77-
arr_dask, np.mean, scale, dtype=arr_numpy.dtype
78-
).compute()
79-
80-
answer_int = answer_float.astype("int")
81-
assert np.array_equal(downscaled_numpy_int, answer_int)
82-
assert np.array_equal(downscaled_dask_int, answer_int)
83-
8471

8572
def test_multiscale():
8673
ndim = 3
@@ -94,7 +81,7 @@ def test_multiscale():
9481

9582
pyr_trimmed = multiscale(array, np.mean, 2, pad_mode=None)
9683
pyr_padded = multiscale(array, np.mean, 2, pad_mode="reflect")
97-
84+
pyr_trimmed_recursive = multiscale(array, np.mean, 2, pad_mode=None, recursive=True)
9885
assert [p.shape for p in pyr_padded] == [
9986
shape,
10087
(5, 5, 5),
@@ -111,4 +98,7 @@ def test_multiscale():
11198
assert np.array_equal(
11299
pyr_trimmed[-2].data.mean().compute(), pyr_trimmed[-1].data.compute().mean()
113100
)
101+
assert np.array_equal(
102+
pyr_trimmed_recursive[-2].data.mean().compute(), pyr_trimmed_recursive[-1].data.compute().mean()
103+
)
114104
assert np.allclose(pyr_padded[0].data.mean().compute(), 0.17146776406035666)

0 commit comments

Comments
 (0)