@@ -54,33 +54,20 @@ def test_downscale_2d():
5454 arr_xarray = DataArray (arr_dask )
5555
5656 downscaled_numpy_float = downscale (
57- arr_numpy , np .mean , scale , preserve_dtype = False
58- ).compute ()
57+ arr_numpy , np .mean , scale ).compute ()
5958
6059 downscaled_dask_float = downscale (
61- arr_dask , np .mean , scale , preserve_dtype = False
62- ).compute ()
60+ arr_dask , np .mean , scale ).compute ()
6361
6462 downscaled_xarray_float = downscale (
65- arr_xarray , np .mean , scale , preserve_dtype = False
66- ).compute ()
63+ arr_xarray , np .mean , scale ).compute ()
6764
6865 answer_float = np .array ([[0.5 , 0.5 , 0.5 , 0.5 ], [0.5 , 0.5 , 0.5 , 0.5 ]])
66+
6967 assert np .array_equal (downscaled_numpy_float , answer_float )
7068 assert np .array_equal (downscaled_dask_float , answer_float )
7169 assert np .array_equal (downscaled_xarray_float , answer_float )
7270
73- downscaled_numpy_int = downscale (
74- arr_numpy , np .mean , scale , dtype = arr_numpy .dtype
75- ).compute ()
76- downscaled_dask_int = downscale (
77- arr_dask , np .mean , scale , dtype = arr_numpy .dtype
78- ).compute ()
79-
80- answer_int = answer_float .astype ("int" )
81- assert np .array_equal (downscaled_numpy_int , answer_int )
82- assert np .array_equal (downscaled_dask_int , answer_int )
83-
8471
8572def test_multiscale ():
8673 ndim = 3
@@ -94,7 +81,7 @@ def test_multiscale():
9481
9582 pyr_trimmed = multiscale (array , np .mean , 2 , pad_mode = None )
9683 pyr_padded = multiscale (array , np .mean , 2 , pad_mode = "reflect" )
97-
84+ pyr_trimmed_recursive = multiscale ( array , np . mean , 2 , pad_mode = None , recursive = True )
9885 assert [p .shape for p in pyr_padded ] == [
9986 shape ,
10087 (5 , 5 , 5 ),
@@ -111,4 +98,7 @@ def test_multiscale():
11198 assert np .array_equal (
11299 pyr_trimmed [- 2 ].data .mean ().compute (), pyr_trimmed [- 1 ].data .compute ().mean ()
113100 )
101+ assert np .array_equal (
102+ pyr_trimmed_recursive [- 2 ].data .mean ().compute (), pyr_trimmed_recursive [- 1 ].data .compute ().mean ()
103+ )
114104 assert np .allclose (pyr_padded [0 ].data .mean ().compute (), 0.17146776406035666 )
0 commit comments