
Commit fb12bd1

update docstrings and tests; change 'recursive' to 'chained' in multiscale signature.
1 parent 63ddaa7 commit fb12bd1

File tree

2 files changed (+92, -65 lines)

src/xarray_multiscale/multiscale.py

Lines changed: 90 additions & 63 deletions
@@ -4,43 +4,60 @@
 from xarray import DataArray
 from typing import Any, List, Optional, Tuple, Union, Sequence, Callable, Dict
 from scipy.interpolate import interp1d
-from dask.array.core import slices_from_chunks, normalize_chunks
 from dask.array import coarsen
 
 
-
 def multiscale(
     array: Any,
     reduction: Callable[[Any], Any],
     scale_factors: Union[Sequence[int], int],
     pad_mode: Optional[str] = None,
     preserve_dtype: bool = True,
     chunks: Optional[Union[Sequence[int], Dict[str, int]]] = None,
-    recursive: bool = False,
+    chained: bool = True,
 ) -> List[DataArray]:
     """
-    Lazily generate a multiscale representation of an array
+    Generate a lazy, coordinate-aware multiscale representation of an array.
 
     Parameters
     ----------
-    array: ndarray to be downscaled.
+    array : numpy array, dask array, or xarray DataArray
+        The array to be downscaled.
+
+    reduction : callable
+        A function that aggregates chunks of data over windows. See the documentation of `dask.array.coarsen` for the expected
+        signature of this callable.
 
-    reduction: a function that aggregates data over windows.
+    scale_factors : iterable of ints
+        The desired downscaling factors, one for each axis.
 
-    scale_factors: an iterable of integers that specifies how much to downscale each axis of the array.
+    pad_mode : string or None, default=None
+        How arrays should be padded prior to downscaling in order to ensure that each array dimension
+        is evenly divisible by the respective scale factor. When set to `None` (default), the input will be sliced before downscaling
+        if its dimensions are not divisible by `scale_factors`.
 
-    pad_mode: How (or if) the input should be padded. When set to `None` the input will be trimmed as needed.
+    preserve_dtype : bool, default=True
+        Determines whether the multiresolution arrays are all cast to the same dtype as the input.
 
-    preserve_dtype: boolean, defaults to True, determines whether lower levels of the pyramid are coerced to the same dtype as the input. This assumes that
-    the reduction function accepts a "dtype" kwarg, e.g. numpy.mean(x, dtype='int').
+    chunks : sequence or dict of ints, or None, default=None
+        If `chunks` is supplied, all output arrays are returned with this chunking. If not None, this
+        argument is passed directly to the `xarray.DataArray.chunk` method of each output array.
 
-    chunks: Sequence or Dict of ints, defaults to None. If `chunks` is supplied, all DataArrays are rechunked with these chunks before being returned.
+    chained : bool, default=True
+        If True (default), the nth downscaled array is generated by applying the reduction function to the (n-1)th
+        downscaled array with the user-supplied `scale_factors`. This means that the nth downscaled array directly depends on the (n-1)th
+        downscaled array. Note that nonlinear reductions like the windowed mode may give inaccurate results with `chained` set to True.
 
-    recursive: boolean, defaults to False. ToDo
+        If False, the nth downscaled array is generated by applying the reduction function to the 0th downscaled array
+        (i.e., the input array) with the `scale_factors` raised to the nth power. This means that the nth downscaled array directly
+        depends on the input array.
 
-    Returns a list of DataArrays, one per level of downscaling. These DataArrays have `coords` properties that track the changing offset (if any)
-    induced by the downsampling operation. Additionally, the scale factors are stored each DataArray's attrs propery under the key `scale_factors`
+    Returns
     -------
+    result : list of DataArrays
+        The `coords` attributes of these DataArrays track the changing offset (if any)
+        induced by the downsampling operation. Additionally, the scale factors are stored in each DataArray's `attrs` property under the key `scale_factors`.
+
 
     """
     needs_padding = pad_mode != None
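To make the new `chained` flag concrete, here is a minimal usage sketch, not part of the commit, assuming the package exposes `multiscale` at the module path shown above and using `np.mean` as the reduction, as the updated tests below do:

import numpy as np
from xarray_multiscale.multiscale import multiscale

data = np.arange(64, dtype="float64").reshape(8, 8)

# chained=True (default): level n is reduced from level n - 1
chained_pyr = multiscale(data, np.mean, (2, 2))

# chained=False: level n is reduced directly from the input,
# using scale_factors raised to the nth power
unchained_pyr = multiscale(data, np.mean, (2, 2), chained=False)

# expected shapes: [(8, 8), (4, 4), (2, 2), (1, 1)]
print([level.shape for level in chained_pyr])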
@@ -56,19 +73,16 @@ def multiscale(
     else:
         padded_shape = prepad(array, scale_factors, pad_mode=pad_mode).shape
 
-    # figure out the maximum depth
     levels = range(
         0, 1 + get_downscale_depth(padded_shape, scale_factors, pad=needs_padding)
     )
-    scales = tuple(
-        tuple(s ** l for s in scale_factors) for l in levels
-    )
+    scales = tuple(tuple(s ** l for s in scale_factors) for l in levels)
     result = [_ingest_array(array, scales[0])]
 
     for level in levels[1:]:
-        if recursive:
+        if chained:
             scale = scale_factors
-            downscaled = downscale(result[-1], reduction, scale, pad_mode=pad_mode)
+            downscaled = downscale(result[-1], reduction, scale, pad_mode=pad_mode)
         else:
             scale = scales[level]
             downscaled = downscale(result[0], reduction, scale, pad_mode=pad_mode)
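The collapsed `scales` one-liner computes, for each level, the cumulative scale factor relative to the input; these are the factors the unchained branch passes to `downscale`. A quick standalone check of what it produces:

scale_factors = (2, 2, 2)
levels = range(4)
scales = tuple(tuple(s ** l for s in scale_factors) for l in levels)
print(scales)  # ((1, 1, 1), (2, 2, 2), (4, 4, 4), (8, 8, 8))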
@@ -80,14 +94,21 @@ def multiscale(
     if chunks is not None:
         if isinstance(chunks, Sequence):
             _chunks = {k: v for k, v in zip(result[0].dims, chunks)}
-        else:
+        elif isinstance(chunks, dict):
             _chunks = chunks
+        else:
+            raise ValueError(
+                f"Chunks must be a Sequence or a dict, not {type(chunks)}"
+            )
         result = [r.chunk(_chunks) for r in result]
 
     return result
 
 
 def _ingest_array(array: Any, scales: Sequence[int]):
+    """
+    Ingest an array in preparation for downscaling
+    """
     if hasattr(array, "coords"):
         # if the input is a xarray.DataArray, assign a new variable to the DataArray and use the variable
         # `array` to refer to the data property of that array
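The new `elif`/`else` makes the accepted `chunks` types explicit: a sequence is matched to the output dims by position, a dict passes through unchanged, and anything else now raises immediately instead of failing later. A standalone sketch of that normalization (hypothetical helper name, for illustration only):

from collections.abc import Sequence

def normalize_chunks_arg(dims, chunks):
    # mirrors the branch in `multiscale` above
    if isinstance(chunks, Sequence):
        return {k: v for k, v in zip(dims, chunks)}
    elif isinstance(chunks, dict):
        return chunks
    else:
        raise ValueError(f"Chunks must be a Sequence or a dict, not {type(chunks)}")

print(normalize_chunks_arg(("dim_0", "dim_1"), (64, 32)))  # {'dim_0': 64, 'dim_1': 32}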
@@ -101,7 +122,7 @@ def _ingest_array(array: Any, scales: Sequence[int]):
         data = da.asarray(array)
         dims = tuple(f"dim_{d}" for d in range(data.ndim))
         coords = {
-            dim: DataArray(offset + np.arange(s, dtype="float32"), dims=dim)
+            dim: DataArray(offset + np.arange(s, dtype="float"), dims=dim)
             for dim, s, offset in zip(dims, array.shape, get_downsampled_offset(scales))
         }
         name = None
@@ -118,7 +139,13 @@ def even_padding(length: int, window: int) -> int:
     Parameters
     ----------
     length : int
-    window: int
+
+    window : int
+
+    Returns
+    -------
+    int
+        Value that, when added to `length`, results in a sum that is evenly divisible by `window`
     """
     return (window - (length % window)) % window
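The function body is the standard modular-arithmetic idiom for rounding a length up to a multiple of `window`; a few worked values, restated outside the diff:

def even_padding(length: int, window: int) -> int:
    return (window - (length % window)) % window

assert even_padding(10, 2) == 0  # 10 is already divisible by 2
assert even_padding(9, 2) == 1   # 9 + 1 == 10
assert even_padding(10, 3) == 2  # 10 + 2 == 12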

@@ -132,8 +159,10 @@ def logn(x: float, n: float) -> float:
     x : float or int.
     n: float or int.
 
-    Returns np.log(x) / np.log(n)
+    Returns
     -------
+    float
+        np.log(x) / np.log(n)
 
     """
     result: float = np.log(x) / np.log(n)
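`logn` is just the change-of-base identity, log base n of x equals log(x) / log(n), which bounds how many times an axis of a given length can be downscaled. For instance:

import numpy as np

def logn(x: float, n: float) -> float:
    return np.log(x) / np.log(n)

print(logn(8, 2))   # ~3.0: a length-8 axis can be halved 3 times
print(logn(81, 3))  # ~4.0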
@@ -147,20 +176,25 @@ def prepad(
     rechunk: bool = True,
 ) -> da.array:
     """
-    Pad an array such that its new dimensions are evenly divisible by some integer.
+    Lazily pad an array such that its new dimensions are evenly divisible by some integer.
 
     Parameters
     ----------
-    array: An ndarray that will be padded.
+    array : ndarray
+        Array that will be padded.
 
-    scale_factors: An iterable of integers. The output array is guaranteed to have dimensions that are each evenly divisible
-    by the corresponding scale factor, and chunks that are smaller than or equal to the scale factor (if the array has chunks)
+    scale_factors : Sequence of ints
+        The output array is guaranteed to have dimensions that are each evenly divisible
+        by the corresponding scale factor, and chunks that are smaller than or equal
+        to the scale factor (if the array has chunks)
 
-    mode: String. The edge mode used by the padding routine. See `dask.array.pad` for more documentation.
+    pad_mode : str
+        The edge mode used by the padding routine. This parameter will be passed to
+        `dask.array.pad` as the `mode` keyword.
 
-    Returns a dask array with padded dimensions.
+    Returns
     -------
-
+    dask array
     """
 
     if pad_mode == None:
@@ -198,7 +232,8 @@ def prepad(
                 extended_coords, dims=k, attrs=old_coord.attrs
             )
         result = DataArray(
-            result, coords=new_coords, dims=array.dims, attrs=array.attrs)
+            result, coords=new_coords, dims=array.dims, attrs=array.attrs
+        )
     return result
 
 
@@ -214,16 +249,22 @@ def downscale(
 
     Parameters
     ----------
-    array: The narray to be downscaled.
+    array : numpy array, dask array, xarray DataArray
+        The array to be downscaled.
 
-    reduction: The function to apply to each window of the array.
+    reduction : callable
+        A function that aggregates chunks of data over windows. See the documentation of `dask.array.coarsen` for the expected
+        signature of this callable.
 
-    scale_factors: A list if ints specifying how much to downscale the array per dimension.
+    scale_factors : iterable of ints
+        The desired downscaling factors, one for each axis.
 
-    trim_excess: A boolean that determines whether the size of the input array should be increased or decreased such that
-    each scale factor tiles its respective array axis. Defaults to False, which will result in the input being padded.
+    trim_excess : bool, default=False
+        Whether the size of the input array should be increased or decreased such that
+        each scale factor tiles its respective array axis. Defaults to False, which will result in the input being padded.
 
-    **kwargs: extra kwargs passed to dask.array.coarsen
+    **kwargs
+        extra kwargs passed to dask.array.coarsen
 
     Returns the downscaled version of the input as a dask array.
     -------
@@ -243,7 +284,6 @@ def downscale(
         **kwargs,
     )
 
-
     if isinstance(array, xarray.DataArray):
         base_coords = array.coords
         new_coords = base_coords
@@ -256,10 +296,19 @@ def downscale(
                 attrs=base_coords[bc].attrs,
             )
             for s, bc, offset, sc in zip(
-                coarsened.shape, base_coords, get_downsampled_offset(scale_factors), scale_factors
+                coarsened.shape,
+                base_coords,
+                get_downsampled_offset(scale_factors),
+                scale_factors,
             )
         )
-        coarsened = DataArray(coarsened, dims=array.dims, coords=new_coords, attrs=array.attrs, name=array.name)
+        coarsened = DataArray(
+            coarsened,
+            dims=array.dims,
+            coords=new_coords,
+            attrs=array.attrs,
+            name=array.name,
+        )
 
     return coarsened
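For reference, `downscale` delegates the windowed aggregation to `dask.array.coarsen`, whose contract the docstring now points to: the reduction is applied to non-overlapping windows described by an axis-to-factor mapping. A self-contained sketch, not part of this commit:

import numpy as np
import dask.array as da
from dask.array import coarsen

darr = da.from_array(np.arange(16, dtype="float64").reshape(4, 4), chunks=2)

# reduce non-overlapping 2x2 windows on both axes with np.mean
coarse = coarsen(np.mean, darr, {0: 2, 1: 2})
print(coarse.shape)            # (2, 2)
print(coarse.compute()[0, 0])  # 2.5, the mean of the top-left 2x2 window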

@@ -324,25 +373,3 @@ def slice_span(sl: slice) -> int:
     Measure the length of a slice
     """
     return sl.stop - sl.start
-
-
-def blocked_pyramid(
-    arr, block_size: Sequence, scale_factors: Sequence[int] = (2, 2, 2), **kwargs
-):
-    full_pyr = multiscale(arr, scale_factors=scale_factors, **kwargs)
-    slices = slices_from_chunks(normalize_chunks(block_size, arr.shape))
-    absolute_block_size = tuple(map(slice_span, slices[0]))
-
-    results = []
-    for idx, sl in enumerate(slices):
-        regions = [
-            tuple(map(downscale_slice, sl, tuple(np.power(scale_factors, exp))))
-            for exp in range(len(full_pyr))
-        ]
-        if tuple(map(slice_span, sl)) == absolute_block_size:
-            pyr = multiscale(arr[sl], scale_factors=scale_factors, **kwargs)
-        else:
-            pyr = [full_pyr[l][r] for l, r in enumerate(regions)]
-        assert len(pyr) == len(regions)
-        results.append((regions, pyr))
-    return results

tests/test_multiscale.py

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ def test_multiscale():
 
     pyr_trimmed = multiscale(array, np.mean, 2, pad_mode=None)
     pyr_padded = multiscale(array, np.mean, 2, pad_mode="reflect")
-    pyr_trimmed_recursive = multiscale(array, np.mean, 2, pad_mode=None, recursive=True)
+    pyr_trimmed_unchained = multiscale(array, np.mean, 2, pad_mode=None, chained=False)
     assert [p.shape for p in pyr_padded] == [
         shape,
         (5, 5, 5),
@@ -99,6 +99,6 @@ def test_multiscale():
         pyr_trimmed[-2].data.mean().compute(), pyr_trimmed[-1].data.compute().mean()
     )
     assert np.array_equal(
-        pyr_trimmed_recursive[-2].data.mean().compute(), pyr_trimmed_recursive[-1].data.compute().mean()
+        pyr_trimmed_unchained[-2].data.mean().compute(), pyr_trimmed_unchained[-1].data.compute().mean()
     )
     assert np.allclose(pyr_padded[0].data.mean().compute(), 0.17146776406035666)
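These assertions on window means hold because the mean composes: averaging small windows of an already-averaged array equals averaging the corresponding larger windows of the original, so chained and unchained pyramids agree for `np.mean`. A 2D sketch of that identity, outside the test suite:

import numpy as np

x = np.arange(16, dtype="float64").reshape(4, 4)

# two successive rounds of 2x2 window means (the chained scheme)...
once = x.reshape(2, 2, 2, 2).mean(axis=(1, 3))
twice = once.reshape(1, 2, 1, 2).mean(axis=(1, 3))

# ...equal a single mean over the full 4x4 window (the unchained scheme)
assert np.allclose(twice, x.mean(keepdims=True))

A windowed mode, by contrast, does not compose this way, which is why the updated docstring cautions against `chained=True` for nonlinear reductions.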
