diff --git a/docs/tutorial/tutorial.py b/docs/tutorial/tutorial.py
index b4bfc42..cd9e79c 100644
--- a/docs/tutorial/tutorial.py
+++ b/docs/tutorial/tutorial.py
@@ -15,6 +15,7 @@
 import skimage.color
 import skimage.data
 import tifffile
+import zarr
 from loguru import logger
 
 import stack_to_chunk
@@ -99,11 +100,47 @@
 
 # %%
-# The levels property can be inspected to show we've added the first level. Ekach level
-# is downsampled by a factor of ``2**level``, so level 0 is downsampled by a factor of
-# 1, which is just a copy of the original data (as expected).
+# The levels property can be inspected to show we've added the first level (level 0):
 print(group.levels)
 
+# %%
+# Next, we will add some downsampled levels to the group. This is done by calling
+# ``add_downsample_level`` on the group object. Each level is linearly downsampled by a
+# factor of :math:`2^{level}`, so level 0 is downsampled by a factor of 1 (original
+# resolution), level 1 is downsampled by a factor of 2, level 2 by a factor of 4,
+# and so on.
+
+group.add_downsample_level(1)
+group.add_downsample_level(2)
+
+# %%
+# We can see from the progress messages that the shape of each level is half that of
+# the previous level (rounded up to the nearest integer). The chunk sizes are kept the
+# same as in the original data (until the image size is less than the chunk size).
+
+# %%
+# We can inspect the levels property again to check that levels 0, 1, and 2 are present:
+print(group.levels)
+
+# %%
+# Note that the downsampled levels have to be added in order, so we can't add level 3
+# before adding level 2 (because the previous level is needed to calculate the next).
+
+# %%
+# Let's plot the downsampled data to see what it looks like.
+# We turn off interpolation (which ``imshow`` applies by default) to make the
+# pixelation at downsampled levels more clearly visible.
+
+fig, ax = plt.subplots(1, 3, figsize=(12, 4))
+for i, level in enumerate(group.levels):
+    data = zarr.open(temp_dir_path / "chunked.zarr" / str(level))
+    first_slice = data[:].transpose(2, 1, 0)[0]
+    ax[i].imshow(first_slice, cmap="gray", interpolation="none")
+    ax[i].set_title(f"Level {level}")
+    ax[i].axis("off")
+fig.tight_layout()
+fig.show()
+
 # %%
 # Cleanup
 # -------
diff --git a/pyproject.toml b/pyproject.toml
index c004b08..9d43e1c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ classifiers = [
 dependencies = [
     "dask",
     "loguru",
+    "scikit-image",
    "zarr",
 ]
 description = "Convert stacks of images to chunked datasets"
diff --git a/ruff.toml b/ruff.toml
index be68566..2e84e63 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -33,6 +33,7 @@ per-file-ignores = {"docs/*" = [
     "D103", # Missing docstring in public function
     "INP001", # is part of an implicit namespace package
     "S101",
+    "SLF001", # Private member accessed
 ]}
 select = [
     "ALL",
diff --git a/src/stack_to_chunk/downsample.py b/src/stack_to_chunk/downsample.py
new file mode 100644
index 0000000..eed26ea
--- /dev/null
+++ b/src/stack_to_chunk/downsample.py
@@ -0,0 +1,82 @@
+"""
+Utilities for downsampling images.
+
+These are based on the ``ome_zarr.dask_utils`` module of the ome-zarr-py library,
+originally contributed by Andreas Eisenbarth @aeisenbarth.
+See https://github.com/toloudis/ome-zarr-py/pull/1
+"""
+
+import numpy as np
+import skimage.transform
+from dask import array as da
+
+
+def _rechunk_to_even(image: da.Array) -> da.Array:
+    """
+    Rechunk the input image so that chunk sizes are even in each dimension.
+
+    This guarantees integer chunk sizes after downsampling by two.
+    """
+    factors = np.array([0.5] * image.ndim)
+    even_chunksize = tuple(
+        np.maximum(1, np.round(np.array(image.chunksize) * factors) / factors).astype(
+            int
+        )
+    )
+    return image.rechunk(even_chunksize)
+
+
+def _half_shape(input_shape: tuple[int, int, int]) -> tuple[int, int, int]:
+    """
+    Calculate the output shape after downsampling by two in each dimension.
+
+    Rounds up to the nearest integer after division.
+    """
+    return tuple(np.ceil(np.array(input_shape) / 2).astype(int))
+
+
+def _resize_block(block: da.Array) -> da.Array:
+    """
+    Downsample a block by a factor of 2 in each dimension using linear interpolation.
+    """
+    new_block_shape = _half_shape(block.shape)
+    return skimage.transform.resize(
+        block,
+        new_block_shape,
+        order=1,
+        anti_aliasing=False,
+    ).astype(block.dtype)
+
+
+def downsample_by_two(image: da.Array) -> da.Array:
+    """
+    Downsample a dask array by two in each dimension.
+
+    Parameters
+    ----------
+    image : da.Array
+        The input image.
+
+    Returns
+    -------
+    da.Array
+        The downsampled image, which has half the size of the input image in each
+        dimension.
+
+    """
+    new_image_shape = _half_shape(image.shape)
+    new_image_slices = tuple(slice(0, d) for d in new_image_shape)
+
+    # Rechunk the image so that chunk sizes will be whole numbers after downsampling
+    image_rechunked = _rechunk_to_even(image)
+    new_image_chunksize = _half_shape(image_rechunked.chunksize)
+
+    new_image = da.map_blocks(
+        _resize_block,
+        image_rechunked,
+        chunks=new_image_chunksize,
+        dtype=image.dtype,
+    )[new_image_slices]
+
+    # Restore the original chunking and dtype
+    return new_image.rechunk(image.chunksize).astype(image.dtype)
diff --git a/src/stack_to_chunk/main.py b/src/stack_to_chunk/main.py
index 820dfe8..932cf30 100644
--- a/src/stack_to_chunk/main.py
+++ b/src/stack_to_chunk/main.py
@@ -6,6 +6,7 @@
 from pathlib import Path
 from typing import Any, Literal
 
+import dask.array as da
 import numpy as np
 import zarr
 from dask.array.core import Array
@@ -14,6 +15,7 @@
 from numcodecs.abc import Codec
 
 from stack_to_chunk._array_helpers import _copy_slab
+from stack_to_chunk.downsample import downsample_by_two
 from stack_to_chunk.ome_ngff import SPATIAL_UNIT
 
 
@@ -99,7 +101,7 @@ def levels(self) -> list[int]:
 
     def add_full_res_data(
         self,
-        data: Array,
+        data: da.Array,
         *,
         chunk_size: int,
         compressor: Literal["default"] | Codec,
@@ -178,18 +180,9 @@ def add_full_res_data(
             p.join()
 
         blosc.use_threads = blosc_use_threads
+        self._add_level_metadata(0)
         logger.info("Finished full resolution copy to zarr.")
 
-        multiscales = self._group.attrs["multiscales"]
-        multiscales[0]["datasets"].append(
-            {
-                "path": "0",
-                "coordinateTransformations": [{"type": "scale", "scale": [1, 1, 1]}],
-            }
-        )
-
-        self._group.attrs["multiscales"] = multiscales
-
     def add_downsample_level(self, level: int) -> None:
         """
         Add a level of downsampling.
@@ -221,12 +214,55 @@ def add_downsample_level(self, level: int) -> None:
                 msg,
             )
 
-        source_data = self._group[level_minus_one]
-        new_shape = np.ceil(np.array(source_data.shape) / 2)
+        logger.info(f"Downsampling level {level_minus_one} to level {level_str}...")
+        # Get the source data from the level below as a dask array
+        source_store = self._group[level_minus_one]
+        source_data = da.from_zarr(source_store, chunks=source_store.chunks)
 
-        self._group[level_str] = zarr.create(
-            new_shape,
-            chunks=source_data.chunks,
-            dtype=source_data.dtype,
-            compressor=source_data.compressor,
+        # Linearly downsample the data by a factor of 2 in each dimension
+        new_data = downsample_by_two(source_data)
+        logger.info(
+            f"Generated level {level_str} array with shape {new_data.shape} "
+            f"and chunk sizes {new_data.chunksize}, using linear interpolation."
         )
+
+        # Create the new zarr store for the downsampled data
+        new_store = self._group.require_dataset(
+            level_str,
+            shape=new_data.shape,
+            chunks=source_store.chunks,
+            dtype=source_store.dtype,
+            compressor=source_store.compressor,
+        )
+        # Write the downsampled data to the new store
+        new_data.to_zarr(new_store, compute=True)
+        self._add_level_metadata(level)
+        logger.info(f"Saved level {level_str} to zarr.")
+
+    def _add_level_metadata(self, level: int = 0) -> None:
+        """
+        Add the required multiscale metadata for the corresponding level.
+
+        Parameters
+        ----------
+        level :
+            Level of downsampling. Level 0 corresponds to full resolution data.
+
+        """
+        # We assume that the scale factor is always 2 in each dimension
+        scale_factors = [float(s * 2**level) for s in self._voxel_size]
+        new_dataset = {
+            "path": str(level),
+            "coordinateTransformations": [
+                {
+                    "type": "scale",
+                    "scale": scale_factors,
+                }
+            ],
+        }
+
+        multiscales = self._group.attrs["multiscales"][0]
+        existing_dataset_paths = [d["path"] for d in multiscales["datasets"]]
+        if new_dataset["path"] not in existing_dataset_paths:
+            multiscales["datasets"].append(new_dataset)
+        self._group.attrs["multiscales"] = [multiscales]
diff --git a/src/stack_to_chunk/tests/test_downsample.py b/src/stack_to_chunk/tests/test_downsample.py
new file mode 100644
index 0000000..7561f0f
--- /dev/null
+++ b/src/stack_to_chunk/tests/test_downsample.py
@@ -0,0 +1,69 @@
+"""Tests for the downsample.py module."""
+
+import dask.array as da
+import numpy as np
+import pytest
+from skimage.transform import resize
+
+from stack_to_chunk.downsample import (
+    _half_shape,
+    _rechunk_to_even,
+    _resize_block,
+    downsample_by_two,
+)
+
+
+class TestDownsample:
+    """Tests for the downsample.py module."""
+
+    shape_3d = tuple[int, int, int]
+    image_shape: shape_3d = (583, 245, 156)
+    image_chunksize: shape_3d = (64, 64, 64)
+    image_darray: da.Array = da.random.randint(
+        low=0, high=2**16, dtype=np.uint16, size=image_shape
+    )
+
+    @pytest.mark.parametrize(
+        "chunksize",
+        [
+            (64, 64, 64),
+            (64, 63, 63),
+            (63, 63, 63),
+        ],
+    )
+    def test_rechunk_to_even(self, chunksize: shape_3d) -> None:
+        """Test rechunking to even chunk sizes."""
+        chunked_array = self.image_darray.rechunk(chunksize)
+        even_chunksize = _rechunk_to_even(chunked_array).chunksize
+        assert even_chunksize == self.image_chunksize
+
+    def test_half_shape(self) -> None:
+        """Test calculating the half shape of an input shape."""
+        assert _half_shape(self.image_shape) == (292, 123, 78)
+        assert _half_shape(self.image_chunksize) == (32, 32, 32)
+
+    def test_resize_block(self) -> None:
+        """Test resizing a single block by a factor of 2 in each dimension."""
+        block = self.image_darray[:64, :64, :64].compute()
+        resized_block = _resize_block(block)
+        assert resized_block.shape == (32, 32, 32)
+        assert resized_block.dtype == block.dtype
+
+    def test_downsample_by_two(self) -> None:
+        """Test downsampling a chunked image by a factor of 2 in each dimension."""
+        input_array = self.image_darray.rechunk(self.image_chunksize)
+        downsampled = downsample_by_two(input_array)
+        assert downsampled.chunksize == input_array.chunksize
+        assert downsampled.dtype == input_array.dtype
+        assert downsampled.ndim == input_array.ndim
+        assert downsampled.shape == _half_shape(input_array.shape)
+
+        # Directly downsample the image (without parallelization)
+        directly_downsampled = resize(
+            input_array,
+            output_shape=_half_shape(input_array.shape),
+            order=1,
+            anti_aliasing=False,
+        ).astype(input_array.dtype)
+
+        np.testing.assert_equal(downsampled.compute(), directly_downsampled)
diff --git a/src/stack_to_chunk/tests/test_main.py b/src/stack_to_chunk/tests/test_main.py
index e1967cb..26453f8 100644
--- a/src/stack_to_chunk/tests/test_main.py
+++ b/src/stack_to_chunk/tests/test_main.py
@@ -8,6 +8,7 @@
 import numpy as np
 import pytest
 import zarr
+from skimage.transform import resize
 
 from stack_to_chunk import MultiScaleGroup, memory_per_process
 
@@ -34,6 +35,20 @@ def test_workflow(tmp_path: Path, arr: da.Array) -> None:
     compressor = numcodecs.blosc.Blosc(cname="zstd", clevel=2, shuffle=2)
     chunk_size = 64
 
+    shape = (583, 245, 156)
+    arr = da.random.randint(low=0, high=2**16, dtype=np.uint16, size=shape)
+    arr = arr.rechunk(chunks=(shape[0], shape[1], 1))
+
+    with pytest.raises(
+        ValueError,
+        match="Input array is must have a chunk size of 1 in the third dimension.",
+    ):
+        group.add_full_res_data(
+            arr.rechunk(chunks=(shape[0], shape[1], 2)),
+            n_processes=2,
+            chunk_size=chunk_size,
+            compressor="default",
+        )
     assert memory_per_process(arr, chunk_size=chunk_size) == 18282880
 
     group.add_full_res_data(
@@ -65,7 +80,7 @@ def test_workflow(tmp_path: Path, arr: da.Array) -> None:
                 "datasets": [
                     {
                         "coordinateTransformations": [
-                            {"scale": [1, 1, 1], "type": "scale"}
+                            {"type": "scale", "scale": [3.0, 4.0, 5.0]}
                         ],
                         "path": "0",
                     }
@@ -93,8 +108,32 @@ def test_workflow(tmp_path: Path, arr: da.Array) -> None:
         compressor="default",
     )
 
+    # Check that adding a downsample level works
     group.add_downsample_level(1)
     assert group.levels == [0, 1]
+    multiscales = group._group.attrs["multiscales"][0]
+    assert multiscales["datasets"] == [
+        {
+            "path": "0",
+            "coordinateTransformations": [{"type": "scale", "scale": [3.0, 4.0, 5.0]}],
+        },
+        {
+            "path": "1",
+            "coordinateTransformations": [{"type": "scale", "scale": [6.0, 8.0, 10.0]}],
+        },
+    ]
+    zarr_arr_1 = zarr.open(zarr_path / "1")
+    shape_1 = (292, 123, 78)
+    assert zarr_arr_1.chunks == zarr_arr.chunks
+    assert zarr_arr_1.shape == shape_1
+    assert zarr_arr_1.dtype == np.uint16
+
+    # The downsampled array should be equal to the original array downsampled
+    # directly with skimage.transform.resize (without chunking/parallelism)
+    directly_downsampled = resize(arr, shape_1, order=1, anti_aliasing=False).astype(
+        np.uint16
+    )
+    np.testing.assert_allclose(directly_downsampled[:], zarr_arr_1[:])
 
     with pytest.raises(RuntimeError, match="Level 1 already found in zarr group"):
         group.add_downsample_level(1)
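Below is a minimal usage sketch (not part of the diff) for the new ``downsample_by_two`` helper. The synthetic array shape, chunk size, and expected results mirror the values used in the new tests, and the assertions only touch lazy dask metadata, so nothing is actually computed.

import dask.array as da
import numpy as np

from stack_to_chunk.downsample import downsample_by_two

# Synthetic uint16 volume with the same shape and chunking as the test suite uses.
image = da.random.randint(
    low=0, high=2**16, dtype=np.uint16, size=(583, 245, 156)
).rechunk((64, 64, 64))

half = downsample_by_two(image)

assert half.shape == (292, 123, 78)  # each dimension is ceil(original / 2)
assert half.chunksize == image.chunksize  # original chunk size is restored
assert half.dtype == np.uint16  # dtype is preserved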
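And a small illustrative sketch of how the per-level ``scale`` entries written by ``_add_level_metadata`` relate to the voxel size (``scale_for_level`` is a hypothetical helper, not part of the package); the expected values match those asserted in test_main.py for a voxel size of ``(3.0, 4.0, 5.0)``.

def scale_for_level(voxel_size: tuple[float, float, float], level: int) -> list[float]:
    # Each downsampling level halves the resolution, so the physical size
    # of a voxel doubles with every level.
    return [float(s * 2**level) for s in voxel_size]


assert scale_for_level((3.0, 4.0, 5.0), 0) == [3.0, 4.0, 5.0]  # level 0: full resolution
assert scale_for_level((3.0, 4.0, 5.0), 1) == [6.0, 8.0, 10.0]  # level 1: downsampled by 2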