
Commit ef721d1

formatting
1 parent d75ec94 commit ef721d1

File tree

7 files changed: +215 -35 lines changed


.pre-commit-config.yaml

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+ci:
+  autoupdate_commit_msg: "chore: update pre-commit hooks"
+  autofix_commit_msg: "style: pre-commit fixes"
+default_stages: [commit, push]
+default_language_version:
+  python: python3
+repos:
+  - repo: https://github.com/psf/black
+    rev: 22.12.0
+    hooks:
+      - id: black
+        language_version: python3.9
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-yaml
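Once this config is in place, the hooks can be enabled locally with ``pre-commit install`` and run over the whole repository with ``pre-commit run --all-files``. The ``ci`` block applies only to the hosted pre-commit.ci service, setting the commit messages it uses for its autoupdate and autofix commits.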

README.rst

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ which returns this (a collection of DataArrays, each with decreasing size):
     Coordinates:
       * dim_0    (dim_0) float64 0.5 2.5]

+By default, the values of the downsampled arrays are cast to the same data type as the input. This behavior can be changed with the ``preserve_dtype`` keyword argument to ``multiscale``.

 Generate a multiscale representation of an ``xarray.DataArray``:
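A minimal sketch of the documented ``preserve_dtype`` behavior (array values are illustrative; imports follow this repository's test module):

import numpy as np

from xarray_multiscale.multiscale import multiscale
from xarray_multiscale.reducers import windowed_mean

data = np.arange(16, dtype="uint8").reshape(4, 4)

# Default: each downsampled level is cast back to the input dtype (uint8 here).
levels = multiscale(data, windowed_mean, 2)
print(levels[-1].dtype)  # uint8

# preserve_dtype=False keeps whatever dtype the reducer produces.
levels = multiscale(data, windowed_mean, 2, preserve_dtype=False)
print(levels[-1].dtype)  # float64, since the windowed mean promotes integer inputs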

poetry.lock

Lines changed: 128 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ sphinx-issues = "^3.0.1"
 pytest-cov = "^3.0.0"
 pytest = "^7.1.2"
 mypy = "^0.971"
+pre-commit = "^3.0.0"

 [build-system]
 requires = ["poetry>=0.12"]

src/xarray_multiscale/chunks.py

Lines changed: 7 additions & 1 deletion
@@ -38,7 +38,13 @@ def normalize_chunks(
         chunk_size = _chunk_size

     new_chunks = map(
-        tz.first, da.core.normalize_chunks(chunk_size, array.shape, dtype=array.dtype)
+        tz.first,
+        da.core.normalize_chunks(
+            chunk_size,
+            array.shape,
+            dtype=array.dtype,
+            previous_chunks=array.data.chunksize,
+        ),
     )

     result = tuple(new_chunks)
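The substantive change here is the new ``previous_chunks`` argument, which gives dask's chunk normalization the input's chunk grid as a template when resolving "auto" chunk requests. A standalone sketch of the underlying dask call (shape and chunk sizes are illustrative):

import dask.array as da

shape = (100_000, 100_000)

# Without a template, "auto" picks a layout from the dtype and a byte-size target.
free = da.core.normalize_chunks("auto", shape, dtype="uint8")

# With previous_chunks (a tuple of per-axis chunk sizes, as the commit passes
# via array.data.chunksize), the chooser aligns new chunk boundaries with the
# existing grid where it can.
aligned = da.core.normalize_chunks(
    "auto", shape, dtype="uint8", previous_chunks=(1000, 1000)
)

print(free[0][0], aligned[0][0])  # the chosen chunk widths can differ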

src/xarray_multiscale/multiscale.py

Lines changed: 9 additions & 6 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Hashable, List, Sequence, Union
+from typing import Any, Dict, Hashable, List, Literal, Sequence, Union

 import numpy as np
 import numpy.typing as npt
@@ -13,13 +13,15 @@
 from xarray_multiscale.reducers import WindowedReducer
 from xarray_multiscale.util import adjust_shape, broadcast_to_rank, logn

+ChunkOption = Literal["preserve", "auto"]
+

 def multiscale(
     array: npt.NDArray[Any],
     reduction: WindowedReducer,
     scale_factors: Union[Sequence[int], int],
     preserve_dtype: bool = True,
-    chunks: Union[str, Sequence[int], Dict[Hashable, int]] = "auto",
+    chunks: Union[str, Sequence[int], Dict[Hashable, int]] = "preserve",
     chained: bool = True,
 ) -> List[DataArray]:
     """
@@ -44,10 +46,11 @@ def multiscale(
         input array. If False, output arrays will have data type determined
         by the output of the reduction function.

-    chunks : sequence or dict of ints, or the string "auto" (default)
+    chunks : sequence or dict of ints, or the string "preserve" (default)
         Set the chunking of the output arrays. Applies only to dask arrays.
-        If `chunks` is set to "auto" (the default), then chunk sizes will
-        decrease with each level of downsampling.
+        If `chunks` is set to "preserve" (the default), then chunk sizes will
+        decrease with each level of downsampling. Otherwise, this argument is
+        passed to `xarray_multiscale.chunks.normalize_chunks`.

         Otherwise, this keyword argument will be passed to the
         `xarray.DataArray.chunk` method for each output array,
@@ -108,7 +111,7 @@ def multiscale(
         source = result[0]
         result.append(downscale(source, reduction, scale, preserve_dtype))

-    if darray.chunks is not None:
+    if darray.chunks is not None and chunks != "preserve":
         new_chunks = [normalize_chunks(r, chunks) for r in result]
         result = [r.chunk(ch) for r, ch in zip(result, new_chunks)]
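The behavioral change: with the new default ``chunks="preserve"``, the final rechunking pass is skipped, so output chunk sizes simply shrink with each downsampling level; any other value is normalized and applied per level. A hedged sketch of the two modes, mirroring the shapes used in this commit's tests:

import dask.array as da

from xarray_multiscale.multiscale import multiscale
from xarray_multiscale.reducers import windowed_mean

arr = da.zeros((16, 16, 16), chunks=(4, 4, 4))

# Default ("preserve"): no rechunking, so chunks halve along with the shape.
preserved = multiscale(arr, windowed_mean, 2)
print([m.data.chunksize for m in preserved])  # (4, 4, 4), (2, 2, 2), ...

# Explicit chunks: each level is rechunked toward the requested layout,
# clipped by that level's shape.
explicit = multiscale(arr, windowed_mean, 2, chunks=(4, 4, 4))
print([m.data.chunksize for m in explicit])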

tests/test_multiscale.py

Lines changed: 53 additions & 27 deletions
@@ -4,9 +4,14 @@
 from xarray import DataArray
 from xarray.testing import assert_equal

-from xarray_multiscale.multiscale import (adjust_shape, downsampling_depth,
-                                          downscale, downscale_coords,
-                                          downscale_dask, multiscale)
+from xarray_multiscale.multiscale import (
+    adjust_shape,
+    downsampling_depth,
+    downscale,
+    downscale_coords,
+    downscale_dask,
+    multiscale,
+)
 from xarray_multiscale.reducers import windowed_mean


@@ -130,33 +135,54 @@ def test_multiscale(ndim: int, chained: bool):

 def test_chunking():
     ndim = 3
-    shape = (9,) * ndim
-    base_array = da.zeros(shape, chunks=(1,) * ndim)
-    chunks = (1,) * ndim
+    shape = (16,) * ndim
+    chunks = (4,) * ndim
+    base_array = da.zeros(shape, chunks=chunks)
     reducer = windowed_mean
-    multi = multiscale(base_array, reducer, 2, chunks=chunks)
-    assert all([m.data.chunksize == chunks for m in multi])
-
-    chunks = (3,) * ndim
-    multi = multiscale(base_array, reducer, 2, chunks=chunks)
-    for m in multi:
-        assert m.data.chunksize == chunks or m.data.chunksize == m.data.shape
-
-    chunks = (3,) * ndim
-    multi = multiscale(base_array, reducer, 2, chunks=chunks)
-    for m in multi:
-        assert (
-            np.greater_equal(m.data.chunksize, chunks).all()
-            or m.data.chunksize == m.data.shape
-        )
+    scale_factors = (2,) * ndim
+
+    multi = multiscale(base_array, reducer, scale_factors)
+    expected_chunks = [
+        np.floor_divide(chunks, [s**idx for s in scale_factors])
+        for idx, m in enumerate(multi)
+    ]
+    expected_chunks = [
+        x
+        if np.all(x)
+        else [
+            1,
+        ]
+        * ndim
+        for x in expected_chunks
+    ]
+    assert all(
+        [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)]
+    )
+
+    multi = multiscale(base_array, reducer, scale_factors, chunks=chunks)
+    expected_chunks = [
+        chunks if np.greater(m.shape, chunks).all() else m.shape
+        for idx, m in enumerate(multi)
+    ]
+    assert all(
+        [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)]
+    )
+
+    chunks = (3, -1, -1)
+    multi = multiscale(base_array, reducer, scale_factors, chunks=chunks)
+    expected_chunks = [
+        (min(chunks[0], m.shape[0]), m.shape[1], m.shape[2]) for m in multi
+    ]
+    assert all(
+        [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)]
+    )

     chunks = 3
-    multi = multiscale(base_array, reducer, 2, chunks=chunks)
-    for m in multi:
-        assert (
-            np.greater_equal(m.data.chunksize, (chunks,) * ndim).all()
-            or m.data.chunksize == m.data.shape
-        )
+    multi = multiscale(base_array, reducer, scale_factors, chunks=chunks)
+    expected_chunks = [tuple(min(chunks, s) for s in m.shape) for m in multi]
+    assert all(
+        [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)]
+    )


 def test_coords():
