Skip to content

Commit d945b29

Browse files
committed
black and fix failure for singleton dimensions
1 parent 229d8bd commit d945b29

File tree

3 files changed

+107
-69
lines changed

3 files changed

+107
-69
lines changed

src/xarray_multiscale/multiscale.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,16 +333,14 @@ def get_downscale_depth(
333333

334334
_scale_factors = np.array(scale_factors).astype("int")
335335
_shape = np.array(shape).astype("int")
336-
if np.all(_scale_factors == 1):
337-
result = 0
338-
elif np.any(_scale_factors > _shape):
336+
valid = (_scale_factors > 1)
337+
if not valid.any():
339338
result = 0
340339
else:
341340
if pad:
342-
depths = np.ceil(logn(shape, scale_factors)).astype("int")
341+
depths = np.ceil(logn(_shape[valid], _scale_factors[valid])).astype("int")
343342
else:
344-
lg = logn(shape, scale_factors)
345-
depths = np.floor(logn(shape, scale_factors)).astype("int")
343+
depths = np.floor(logn(_shape[valid], _scale_factors[valid])).astype("int")
346344
result = min(depths)
347345
return result
348346

tests/test_multiscale.py

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
multiscale,
1010
get_downscale_depth,
1111
normalize_chunks,
12-
ensure_minimum_chunks
12+
ensure_minimum_chunks,
13+
)
14+
from xarray_multiscale.reducers import (
15+
reshape_with_windows,
16+
windowed_mean,
17+
windowed_mode,
1318
)
14-
from xarray_multiscale.reducers import reshape_with_windows, windowed_mean, windowed_mode
1519
import dask.array as da
1620
import numpy as np
1721
from xarray import DataArray
@@ -20,6 +24,8 @@
2024

2125
def test_downscale_depth():
2226
assert get_downscale_depth((1,), (1,)) == 0
27+
assert get_downscale_depth((2,), (3,)) == 0
28+
assert get_downscale_depth((2, 1), (2, 1)) == 1
2329
assert get_downscale_depth((2, 2, 2), (2, 2, 2)) == 1
2430
assert get_downscale_depth((1, 2, 2), (2, 2, 2)) == 0
2531
assert get_downscale_depth((4, 4, 4), (2, 2, 2)) == 2
@@ -31,72 +37,88 @@ def test_downscale_depth():
3137
assert get_downscale_depth((1500, 5495, 5200), (2, 2, 2)) == 10
3238

3339

34-
@pytest.mark.parametrize(("size", "scale"), ((10, 2), (11, 2), ((10,11), (2,3))))
40+
@pytest.mark.parametrize(("size", "scale"), ((10, 2), (11, 2), ((10, 11), (2, 3))))
3541
def test_adjust_shape(size, scale):
3642
arr = DataArray(np.zeros(size))
3743
padded = adjust_shape(arr, scale, mode="constant")
3844
scale_array = np.array(scale)
3945
old_shape_array = np.array(arr.shape)
4046
new_shape_array = np.array(padded.shape)
41-
47+
4248
if np.all((old_shape_array % scale_array) == 0):
4349
assert np.array_equal(new_shape_array, old_shape_array)
4450
else:
45-
assert np.array_equal(new_shape_array, old_shape_array + ((scale_array - (old_shape_array % scale_array))))
51+
assert np.array_equal(
52+
new_shape_array,
53+
old_shape_array + ((scale_array - (old_shape_array % scale_array))),
54+
)
4655

4756
cropped = adjust_shape(arr, scale, mode="crop")
4857
new_shape_array = np.array(cropped.shape)
4958
if np.all((old_shape_array % scale_array) == 0):
5059
assert np.array_equal(new_shape_array, old_shape_array)
5160
else:
52-
assert np.array_equal(new_shape_array, old_shape_array - (old_shape_array % scale_array))
61+
assert np.array_equal(
62+
new_shape_array, old_shape_array - (old_shape_array % scale_array)
63+
)
64+
5365

5466
def test_downscale_2d():
5567
chunks = (2, 2)
5668
scale = (2, 1)
5769

58-
data = DataArray(da.from_array(np.array(
59-
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype="uint8"
60-
), chunks=chunks))
70+
data = DataArray(
71+
da.from_array(
72+
np.array(
73+
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype="uint8"
74+
),
75+
chunks=chunks,
76+
)
77+
)
6178
answer = DataArray(np.array([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]))
62-
downscaled = downscale(data, windowed_mean, scale, pad_mode='crop').compute()
79+
downscaled = downscale(data, windowed_mean, scale, pad_mode="crop").compute()
6380
assert np.array_equal(downscaled, answer)
6481

6582

6683
def test_downscale_coords():
67-
data = DataArray(np.zeros((10, 10)), dims=('x','y'), coords={'x': np.arange(10)})
68-
scale_factors = (2,1)
84+
data = DataArray(np.zeros((10, 10)), dims=("x", "y"), coords={"x": np.arange(10)})
85+
scale_factors = (2, 1)
6986
downscaled = downscale_coords(data, scale_factors)
70-
answer = {'x': data['x'].coarsen({'x' : scale_factors[0]}).mean()}
71-
87+
answer = {"x": data["x"].coarsen({"x": scale_factors[0]}).mean()}
88+
7289
assert downscaled.keys() == answer.keys()
7390
for k in downscaled:
7491
assert_equal(answer[k], downscaled[k])
7592

76-
data = DataArray(np.zeros((10, 10)),
77-
dims=('x','y'),
78-
coords={'x': np.arange(10),
79-
'y': 5 + np.arange(10)})
80-
scale_factors = (2,1)
93+
data = DataArray(
94+
np.zeros((10, 10)),
95+
dims=("x", "y"),
96+
coords={"x": np.arange(10), "y": 5 + np.arange(10)},
97+
)
98+
scale_factors = (2, 1)
8199
downscaled = downscale_coords(data, scale_factors)
82-
answer = {'x': data['x'].coarsen({'x' : scale_factors[0]}).mean(),
83-
'y' : data['y'].coarsen({'y' : scale_factors[1]}).mean()}
84-
100+
answer = {
101+
"x": data["x"].coarsen({"x": scale_factors[0]}).mean(),
102+
"y": data["y"].coarsen({"y": scale_factors[1]}).mean(),
103+
}
104+
85105
assert downscaled.keys() == answer.keys()
86106
for k in downscaled:
87107
assert_equal(answer[k], downscaled[k])
88108

89-
data = DataArray(np.zeros((10, 10)),
90-
dims=('x','y'),
91-
coords={'x': np.arange(10),
92-
'y': 5 + np.arange(10),
93-
'foo' : 5})
94-
scale_factors = (2,2)
109+
data = DataArray(
110+
np.zeros((10, 10)),
111+
dims=("x", "y"),
112+
coords={"x": np.arange(10), "y": 5 + np.arange(10), "foo": 5},
113+
)
114+
scale_factors = (2, 2)
95115
downscaled = downscale_coords(data, scale_factors)
96-
answer = {'x': data['x'].coarsen({'x' : scale_factors[0]}).mean(),
97-
'y' : data['y'].coarsen({'y' : scale_factors[1]}).mean(),
98-
'foo': data['foo']}
99-
116+
answer = {
117+
"x": data["x"].coarsen({"x": scale_factors[0]}).mean(),
118+
"y": data["y"].coarsen({"y": scale_factors[1]}).mean(),
119+
"foo": data["foo"],
120+
}
121+
100122
assert downscaled.keys() == answer.keys()
101123
for k in downscaled:
102124
assert_equal(answer[k], downscaled[k])
@@ -106,7 +128,7 @@ def test_invalid_multiscale():
106128
with pytest.raises(ValueError):
107129
downscale_dask(np.arange(10), windowed_mean, (3,))
108130
with pytest.raises(ValueError):
109-
downscale_dask(np.arange(16).reshape(4,4), windowed_mean, (3,3))
131+
downscale_dask(np.arange(16).reshape(4, 4), windowed_mean, (3, 3))
110132

111133

112134
def test_multiscale():
@@ -163,14 +185,20 @@ def test_chunking():
163185
assert m.data.chunksize == chunks or m.data.chunksize == m.data.shape
164186

165187
chunks = (3,) * ndim
166-
multi = multiscale(base_array, reducer, 2, chunks=chunks, chunk_mode='minimum')
188+
multi = multiscale(base_array, reducer, 2, chunks=chunks, chunk_mode="minimum")
167189
for m in multi:
168-
assert np.greater_equal(m.data.chunksize, chunks).all() or m.data.chunksize == m.data.shape
190+
assert (
191+
np.greater_equal(m.data.chunksize, chunks).all()
192+
or m.data.chunksize == m.data.shape
193+
)
169194

170195
chunks = 3
171-
multi = multiscale(base_array, reducer, 2, chunks=chunks, chunk_mode='minimum')
196+
multi = multiscale(base_array, reducer, 2, chunks=chunks, chunk_mode="minimum")
172197
for m in multi:
173-
assert np.greater_equal(m.data.chunksize, (chunks,) * ndim).all() or m.data.chunksize == m.data.shape
198+
assert (
199+
np.greater_equal(m.data.chunksize, (chunks,) * ndim).all()
200+
or m.data.chunksize == m.data.shape
201+
)
174202

175203

176204
def test_depth():
@@ -182,16 +210,16 @@ def test_depth():
182210
assert len(full) == 5
183211

184212
partial = multiscale(base_array, reducer, 2, depth=-2)
185-
assert len(partial) == len(full) - 1
186-
[assert_equal(a,b) for a,b in zip(full, partial)]
213+
assert len(partial) == len(full) - 1
214+
[assert_equal(a, b) for a, b in zip(full, partial)]
187215

188216
partial = multiscale(base_array, reducer, 2, depth=2)
189-
assert len(partial) == 3
190-
[assert_equal(a,b) for a,b in zip(full, partial)]
217+
assert len(partial) == 3
218+
[assert_equal(a, b) for a, b in zip(full, partial)]
191219

192220
partial = multiscale(base_array, reducer, 2, depth=0)
193-
assert len(partial) == 1
194-
[assert_equal(a,b) for a,b in zip(full, partial)]
221+
assert len(partial) == 1
222+
[assert_equal(a, b) for a, b in zip(full, partial)]
195223

196224

197225
def test_coords():
@@ -215,23 +243,23 @@ def test_coords():
215243

216244

217245
def test_normalize_chunks():
218-
data = DataArray(da.zeros((4,6), chunks=(1,1)))
219-
assert normalize_chunks(data, {'dim_0' : 2, 'dim_1' : 1}) == (2,1)
246+
data = DataArray(da.zeros((4, 6), chunks=(1, 1)))
247+
assert normalize_chunks(data, {"dim_0": 2, "dim_1": 1}) == (2, 1)
220248

221249

222250
def test_ensure_minimum_chunks():
223-
data = da.zeros((4,6), chunks=(1,1))
224-
assert ensure_minimum_chunks(data, (2,2)) == (2,2)
251+
data = da.zeros((4, 6), chunks=(1, 1))
252+
assert ensure_minimum_chunks(data, (2, 2)) == (2, 2)
225253

226-
data = da.zeros((4,6), chunks=(4,1))
227-
assert ensure_minimum_chunks(data, (2,2)) == (4,2)
254+
data = da.zeros((4, 6), chunks=(4, 1))
255+
assert ensure_minimum_chunks(data, (2, 2)) == (4, 2)
228256

229257

230258
def test_broadcast_to_rank():
231259
assert broadcast_to_rank(2, 1) == (2,)
232-
assert broadcast_to_rank(2, 2) == (2,2)
233-
assert broadcast_to_rank((2,3), 2) == (2,3)
234-
assert broadcast_to_rank({0 : 2}, 3) == (2,1,1)
260+
assert broadcast_to_rank(2, 2) == (2, 2)
261+
assert broadcast_to_rank((2, 3), 2) == (2, 3)
262+
assert broadcast_to_rank({0: 2}, 3) == (2, 1, 1)
235263

236264

237265
def test_align_chunks():
@@ -243,11 +271,18 @@ def test_align_chunks():
243271
rechunked = align_chunks(data, scale_factors=(2,))
244272
assert rechunked.chunks == ((2,) * 5,)
245273

246-
data = da.arange(10, chunks=(1,1,3,5))
274+
data = da.arange(10, chunks=(1, 1, 3, 5))
247275
rechunked = align_chunks(data, scale_factors=(2,))
248-
assert rechunked.chunks == ((2, 2, 2, 4,),)
276+
assert rechunked.chunks == (
277+
(
278+
2,
279+
2,
280+
2,
281+
4,
282+
),
283+
)
249284

250285

251286
def test_reshape_with_windows():
252-
data = np.arange(36).reshape(6,6)
253-
assert reshape_with_windows(data, (2,2)).shape == (3,2,3,2)
287+
data = np.arange(36).reshape(6, 6)
288+
assert reshape_with_windows(data, (2, 2)).shape == (3, 2, 3, 2)

tests/test_reducers.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,18 @@ def test_windowed_mode():
88
results = windowed_mode(data, (4,))
99
assert np.array_equal(results, answer)
1010

11-
data = np.arange(16).reshape(4,4) % 3
12-
answer = np.array([[1,0],[0,2]])
13-
results = windowed_mode(data, (2,2))
11+
data = np.arange(16).reshape(4, 4) % 3
12+
answer = np.array([[1, 0], [0, 2]])
13+
results = windowed_mode(data, (2, 2))
1414
assert np.array_equal(results, answer)
1515

16+
1617
def test_windowed_mean():
17-
data = np.arange(16).reshape(4,4) % 2
18-
answer = np.array([[0.5, 0.5],[0.5, 0.5]])
19-
results = windowed_mean(data, (2,2))
20-
assert np.array_equal(results, answer)
18+
data = np.arange(16).reshape(4, 4) % 2
19+
answer = np.array([[0.5, 0.5], [0.5, 0.5]])
20+
results = windowed_mean(data, (2, 2))
21+
assert np.array_equal(results, answer)
22+
23+
data = np.arange(16).reshape(4, 4, 1) % 2
24+
answer = np.array([[0.5, 0.5], [0.5, 0.5]]).reshape((2,2,1))
25+
results = windowed_mean(data, (2, 2, 1))

0 commit comments

Comments
 (0)