-# Copyright 2024 (C) BioVisionCenter, University of Zurich
+# Copyright 2025 (C) BioVisionCenter, University of Zurich
 #
 # Original authors:
 from typing import Any, Optional
 
 import ngio
-from ngio.core.utils import create_empty_ome_zarr_label
 from pydantic import validate_call
 
+from fractal_helper_tasks.utils import normalize_chunk_size_dict, rechunk_label
+
 logger = logging.getLogger(__name__)
 
 
@@ -22,6 +23,7 @@ def rechunk_zarr(
     zarr_url: str,
     chunk_sizes: Optional[dict[str, Optional[int]]] = None,
     suffix: str = "rechunked",
+    rechunk_labels: bool = True,
     rebuild_pyramids: bool = True,
     overwrite_input: bool = True,
     overwrite: bool = False,
@@ -38,6 +40,8 @@ def rechunk_zarr(
             sizes. {"z": 10} will just change the Z chunking while keeping
             all other chunk sizes the same as the input.
         suffix: Suffix of the rechunked image.
+        rechunk_labels: Whether to apply the same rechunking to all label
+            images of the OME-Zarr as well.
         rebuild_pyramids: Whether pyramids are built fresh in the rechunked
             image. This has a small performance overhead, but ensures that
             this task is safe against off-by-one issues when pyramid levels
@@ -47,11 +51,9 @@ def rechunk_zarr(
         overwrite: Whether to overwrite potential pre-existing output with the
             name zarr_url_suffix.
     """
-    chunk_sizes = chunk_sizes or {}
-    valid_axes = ["t", "c", "z", "y", "x"]
-    for axis in valid_axes:
-        if axis not in chunk_sizes:
-            chunk_sizes[axis] = None
+    logger.info(f"Running `rechunk_zarr` on {zarr_url=} with {chunk_sizes=}.")
+
+    chunk_sizes = normalize_chunk_size_dict(chunk_sizes)
 
     rechunked_zarr_url = zarr_url + f"_{suffix}"
     ngff_image = ngio.NgffImage(zarr_url)
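For context: normalize_chunk_size_dict comes from fractal_helper_tasks.utils and its body is not part of this diff. A minimal sketch of what it presumably does, based on the inline normalization it replaces, could look as follows (hypothetical; the real helper may fill in default axes or validate values differently):

    from typing import Optional

    def normalize_chunk_size_dict(
        chunk_sizes: Optional[dict[str, Optional[int]]],
    ) -> dict[str, Optional[int]]:
        # Assumed behavior: treat a missing dict as "no rechunking requested"
        # and normalize axis names to lower case, so {"Z": 10} and {"z": 10}
        # are handled the same way by the per-axis loop in rechunk_zarr.
        chunk_sizes = chunk_sizes or {}
        return {axis.lower(): size for axis, size in chunk_sizes.items()}
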
@@ -62,19 +64,27 @@ def rechunk_zarr(
 
     # Compute the chunksize tuple
     new_chunksize = [c[0] for c in chunks]
+    logger.info(f"Initial chunk sizes were: {chunks}")
     # Overwrite chunk_size with user-set chunksize
     for i, axis in enumerate(axes_names):
         if axis in chunk_sizes:
             if chunk_sizes[axis] is not None:
                 new_chunksize[i] = chunk_sizes[axis]
 
-    # TODO: Check for extra axes specified
+    for axis in chunk_sizes:
+        if axis not in axes_names:
+            raise NotImplementedError(
+                f"Rechunking with {axis=} is specified, but the OME-Zarr only "
+                f"has the following axes: {axes_names}"
+            )
+
+    logger.info(f"Chunk sizes after rechunking will be: {new_chunksize=}")
 
     new_ngff_image = ngff_image.derive_new_image(
         store=rechunked_zarr_url,
         name=ngff_image.image_meta.name,
         overwrite=overwrite,
-        copy_labels=False,  # Copy if rechunk labels is not selected?
+        copy_labels=not rechunk_labels,
         copy_tables=True,
         chunks=new_chunksize,
     )
@@ -93,79 +103,37 @@ def rechunk_zarr(
                 ngff_image.get_image(path=path).on_disk_dask_array
             )
 
-    # Copy labels: Loop over them
-    # Labels don't have a channel dimension
-    chunk_sizes["c"] = None
-    label_names = ngff_image.labels.list()
-    for label in label_names:
-        old_label = ngff_image.labels.get_label(name=label)
-        label_level_paths = ngff_image.labels.levels_paths(name=label)
-        # Compute the chunksize tuple
-        chunks = old_label.on_disk_dask_array.chunks
-        new_chunksize = [c[0] for c in chunks]
-        # Overwrite chunk_size with user-set chunksize
-        for i, axis in enumerate(old_label.dataset.on_disk_axes_names):
-            if axis in chunk_sizes:
-                if chunk_sizes[axis] is not None:
-                    new_chunksize[i] = chunk_sizes[axis]
-        create_empty_ome_zarr_label(
-            store=new_ngff_image.store
-            + "/"
-            + "labels"
-            + "/"
-            + label,  # FIXME: Set this better?
-            on_disk_shape=old_label.on_disk_shape,
-            chunks=new_chunksize,
-            dtype=old_label.on_disk_dask_array.dtype,
-            on_disk_axis=old_label.dataset.on_disk_axes_names,
-            pixel_sizes=old_label.dataset.pixel_size,
-            xy_scaling_factor=old_label.metadata.xy_scaling_factor,
-            z_scaling_factor=old_label.metadata.z_scaling_factor,
-            time_spacing=old_label.dataset.time_spacing,
-            time_units=old_label.dataset.time_axis_unit,
-            levels=label_level_paths,
-            name=label,
-            overwrite=overwrite,
-            version=old_label.metadata.version,
-        )
-
-        # Fill in labels.attrs to contain the label name
-        list_of_labels = new_ngff_image.labels.list()
-        if label not in list_of_labels:
-            new_ngff_image.labels._label_group.attrs["labels"] = [
-                *list_of_labels,
-                label,
-            ]
-
-        if rebuild_pyramids:
-            # Set the highest resolution, then consolidate to build a new pyramid
-            new_ngff_image.labels.get_label(
-                name=label, highest_resolution=True
-            ).set_array(
-                ngff_image.labels.get_label(
-                    name=label, highest_resolution=True
-                ).on_disk_dask_array
+    # Copy labels
+    if rechunk_labels:
+        chunk_sizes["c"] = None
+        label_names = ngff_image.labels.list()
+        for label in label_names:
+            rechunk_label(
+                orig_ngff_image=ngff_image,
+                new_ngff_image=new_ngff_image,
+                label=label,
+                chunk_sizes=chunk_sizes,
+                overwrite=overwrite,
+                rebuild_pyramids=rebuild_pyramids,
             )
-            new_ngff_image.labels.get_label(
-                name=label, highest_resolution=True
-            ).consolidate()
-        else:
-            for label_path in label_level_paths:
-                new_ngff_image.labels.get_label(name=label, path=label_path).set_array(
-                    ngff_image.labels.get_label(
-                        name=label, path=label_path
-                    ).on_disk_dask_array
-                )
+
     if overwrite_input:
         os.rename(zarr_url, f"{zarr_url}_tmp")
         os.rename(rechunked_zarr_url, zarr_url)
         shutil.rmtree(f"{zarr_url}_tmp")
         return
     else:
-        image_list_updates = dict(
-            image_list_updates=[dict(zarr_url=rechunked_zarr_url, origin=zarr_url)]
+        output = dict(
+            image_list_updates=[
+                dict(
+                    zarr_url=rechunked_zarr_url,
+                    origin=zarr_url,
+                    types=dict(rechunked=True),
+                )
+            ],
+            filters=dict(types=dict(rechunked=True)),
         )
-        return image_list_updates
+        return output
 
 
 if __name__ == "__main__":
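As a usage sketch (illustrative only: the import path and the example zarr_url are assumptions, and within Fractal the task is normally invoked by the server rather than called directly):

    from fractal_helper_tasks.rechunk_zarr import rechunk_zarr

    # Rechunk a single OME-Zarr image so that each chunk holds 10 Z planes,
    # keeping all other chunk sizes as they are; labels get the same treatment.
    rechunk_zarr(
        zarr_url="/path/to/my_plate.zarr/B/03/0",
        chunk_sizes={"z": 10},
        rechunk_labels=True,
        overwrite_input=False,
    )

With overwrite_input=False, the original image is kept and the task returns an image_list_updates entry pointing at the new *_rechunked image (tagged with the rechunked type); with the default overwrite_input=True, the rechunked copy replaces the original in place.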