4 | 4 | from xarray import DataArray |
5 | 5 | from xarray.testing import assert_equal |
6 | 6 |
7 | | -from xarray_multiscale.multiscale import (adjust_shape, downsampling_depth, |
8 | | - downscale, downscale_coords, |
9 | | - downscale_dask, multiscale) |
| 7 | +from xarray_multiscale.multiscale import ( |
| 8 | + adjust_shape, |
| 9 | + downsampling_depth, |
| 10 | + downscale, |
| 11 | + downscale_coords, |
| 12 | + downscale_dask, |
| 13 | + multiscale, |
| 14 | +) |
10 | 15 | from xarray_multiscale.reducers import windowed_mean |
11 | 16 |
12 | 17 |
@@ -130,33 +135,54 @@ def test_multiscale(ndim: int, chained: bool): |
130 | 135 |
131 | 136 | def test_chunking(): |
132 | 137 | ndim = 3 |
133 | | - shape = (9,) * ndim |
134 | | - base_array = da.zeros(shape, chunks=(1,) * ndim) |
135 | | - chunks = (1,) * ndim |
| 138 | + shape = (16,) * ndim |
| 139 | + chunks = (4,) * ndim |
| 140 | + base_array = da.zeros(shape, chunks=chunks) |
136 | 141 | reducer = windowed_mean |
137 | | - multi = multiscale(base_array, reducer, 2, chunks=chunks) |
138 | | - assert all([m.data.chunksize == chunks for m in multi]) |
139 | | - |
140 | | - chunks = (3,) * ndim |
141 | | - multi = multiscale(base_array, reducer, 2, chunks=chunks) |
142 | | - for m in multi: |
143 | | - assert m.data.chunksize == chunks or m.data.chunksize == m.data.shape |
144 | | - |
145 | | - chunks = (3,) * ndim |
146 | | - multi = multiscale(base_array, reducer, 2, chunks=chunks) |
147 | | - for m in multi: |
148 | | - assert ( |
149 | | - np.greater_equal(m.data.chunksize, chunks).all() |
150 | | - or m.data.chunksize == m.data.shape |
151 | | - ) |
| 142 | + scale_factors = (2,) * ndim |
| 143 | + |
| 144 | + multi = multiscale(base_array, reducer, scale_factors) |
| 145 | + expected_chunks = [ |
| 146 | + np.floor_divide(chunks, [s**idx for s in scale_factors]) |
| 147 | + for idx, m in enumerate(multi) |
| 148 | + ] |
| 149 | + expected_chunks = [ |
| 150 | + x |
| 151 | + if np.all(x) |
| 152 | + else [ |
| 153 | + 1, |
| 154 | + ] |
| 155 | + * ndim |
| 156 | + for x in expected_chunks |
| 157 | + ] |
| 158 | + assert all( |
| 159 | + [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)] |
| 160 | + ) |
| 161 | + |
| 162 | + multi = multiscale(base_array, reducer, scale_factors, chunks=chunks) |
| 163 | + expected_chunks = [ |
| 164 | + chunks if np.greater(m.shape, chunks).all() else m.shape |
| 165 | + for idx, m in enumerate(multi) |
| 166 | + ] |
| 167 | + assert all( |
| 168 | + [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)] |
| 169 | + ) |
| 170 | + |
| 171 | + chunks = (3, -1, -1) |
| 172 | + multi = multiscale(base_array, reducer, scale_factors, chunks=chunks) |
| 173 | + expected_chunks = [ |
| 174 | + (min(chunks[0], m.shape[0]), m.shape[1], m.shape[2]) for m in multi |
| 175 | + ] |
| 176 | + assert all( |
| 177 | + [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)] |
| 178 | + ) |
152 | 179 |
153 | 180 | chunks = 3 |
154 | | - multi = multiscale(base_array, reducer, 2, chunks=chunks) |
155 | | - for m in multi: |
156 | | - assert ( |
157 | | - np.greater_equal(m.data.chunksize, (chunks,) * ndim).all() |
158 | | - or m.data.chunksize == m.data.shape |
159 | | - ) |
| 181 | + multi = multiscale(base_array, reducer, scale_factors, chunks=chunks) |
| 182 | + expected_chunks = [tuple(min(chunks, s) for s in m.shape) for m in multi] |
| 183 | + assert all( |
| 184 | + [np.array_equal(m.data.chunksize, e) for m, e in zip(multi, expected_chunks)] |
| 185 | + ) |
160 | 186 |
161 | 187 |
162 | 188 | def test_coords(): |
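
The expected-chunks arithmetic in the rewritten test is dense enough to be worth unpacking. Below is a minimal sketch of the same computation using the test's own values (shape (16,) * 3, chunks (4,) * 3, scale factors (2,) * 3); the loop and the level count are illustrative assumptions, not xarray_multiscale API:

import numpy as np

chunks = (4, 4, 4)
scale_factors = (2, 2, 2)
# At level idx, each base chunk shrinks by scale_factor ** idx via floor
# division; if any axis would reach 0, fall back to one element per axis
# so dask still receives a valid chunk size.
# Assumes five pyramid levels: a 16-wide axis halved per level gives
# 16 -> 8 -> 4 -> 2 -> 1.
for idx in range(5):
    shrunk = [int(c) for c in np.floor_divide(chunks, [s**idx for s in scale_factors])]
    level_chunks = tuple(shrunk) if all(shrunk) else (1,) * len(chunks)
    print(idx, level_chunks)
# 0 (4, 4, 4)
# 1 (2, 2, 2)
# 2 (1, 1, 1)
# 3 (1, 1, 1)
# 4 (1, 1, 1)

The later cases lean on dask's chunk-spec conventions: -1 requests a single chunk spanning the whole axis, and a scalar is broadcast to every axis, which is why the expected values clamp each dimension with min(chunks, shape).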