Skip to content

Commit 9e7b1cf

Browse files
committed
Update rechunk task & add tests
1 parent b5b6cca commit 9e7b1cf

File tree

4 files changed

+296
-75
lines changed

4 files changed

+296
-75
lines changed

src/fractal_helper_tasks/drop_t_dimension.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def drop_t_dimension(
5353
(standard argument for Fractal tasks, managed by Fractal server).
5454
suffix: Suffix to be used for the new Zarr image. If overwrite_input
5555
is True, this file is only temporary.
56-
overwrite_input: Whether
56+
overwrite_input: Whether the existing iamge should be overwritten with
57+
the new OME-Zarr without the T dimension.
5758
"""
5859
# Normalize zarr_url
5960
zarr_url_old = zarr_url.rstrip("/")

src/fractal_helper_tasks/rechunk_zarr.py

Lines changed: 42 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 (C) BioVisionCenter, University of Zurich
1+
# Copyright 2025 (C) BioVisionCenter, University of Zurich
22
#
33
# Original authors:
44
# Joel Lüthi <[email protected]>
@@ -10,9 +10,10 @@
1010
from typing import Any, Optional
1111

1212
import ngio
13-
from ngio.core.utils import create_empty_ome_zarr_label
1413
from pydantic import validate_call
1514

15+
from fractal_helper_tasks.utils import normalize_chunk_size_dict, rechunk_label
16+
1617
logger = logging.getLogger(__name__)
1718

1819

@@ -22,6 +23,7 @@ def rechunk_zarr(
2223
zarr_url: str,
2324
chunk_sizes: Optional[dict[str, Optional[int]]] = None,
2425
suffix: str = "rechunked",
26+
rechunk_labels: bool = True,
2527
rebuild_pyramids: bool = True,
2628
overwrite_input: bool = True,
2729
overwrite: bool = False,
@@ -38,6 +40,8 @@ def rechunk_zarr(
3840
sizes. {"z": 10} will just change the Z chunking while keeping
3941
all other chunk sizes the same as the input.
4042
suffix: Suffix of the rechunked image.
43+
rechunk_labels: Whether to apply the same rechunking to all label
44+
images of the OME-Zarr as well.
4145
rebuild_pyramids: Whether pyramids are built fresh in the rechunked
4246
image. This has a small performance overhead, but ensures that
4347
this task is save against off-by-one issues when pyramid levels
@@ -47,11 +51,9 @@ def rechunk_zarr(
4751
overwrite: Whether to overwrite potential pre-existing output with the
4852
name zarr_url_suffix.
4953
"""
50-
chunk_sizes = chunk_sizes or {}
51-
valid_axes = ["t", "c", "z", "y", "x"]
52-
for axis in valid_axes:
53-
if axis not in chunk_sizes:
54-
chunk_sizes[axis] = None
54+
logger.info(f"Running `rechunk_zarr` on {zarr_url=} with {chunk_sizes=}.")
55+
56+
chunk_sizes = normalize_chunk_size_dict(chunk_sizes)
5557

5658
rechunked_zarr_url = zarr_url + f"_{suffix}"
5759
ngff_image = ngio.NgffImage(zarr_url)
@@ -62,19 +64,27 @@ def rechunk_zarr(
6264

6365
# Compute the chunksize tuple
6466
new_chunksize = [c[0] for c in chunks]
67+
logger.info(f"Initial chunk sizes were: {chunks}")
6568
# Overwrite chunk_size with user-set chunksize
6669
for i, axis in enumerate(axes_names):
6770
if axis in chunk_sizes:
6871
if chunk_sizes[axis] is not None:
6972
new_chunksize[i] = chunk_sizes[axis]
7073

71-
# TODO: Check for extra axes specified
74+
for axis in chunk_sizes:
75+
if axis not in axes_names:
76+
raise NotImplementedError(
77+
f"Rechunking with {axis=} is specified, but the OME-Zarr only "
78+
f"has the following axes: {axes_names}"
79+
)
80+
81+
logger.info(f"Chunk sizes after rechunking will be: {new_chunksize=}")
7282

7383
new_ngff_image = ngff_image.derive_new_image(
7484
store=rechunked_zarr_url,
7585
name=ngff_image.image_meta.name,
7686
overwrite=overwrite,
77-
copy_labels=False, # Copy if rechunk labels is not selected?
87+
copy_labels=not rechunk_labels,
7888
copy_tables=True,
7989
chunks=new_chunksize,
8090
)
@@ -93,79 +103,37 @@ def rechunk_zarr(
93103
ngff_image.get_image(path=path).on_disk_dask_array
94104
)
95105

96-
# Copy labels: Loop over them
97-
# Labels don't have a channel dimension
98-
chunk_sizes["c"] = None
99-
label_names = ngff_image.labels.list()
100-
for label in label_names:
101-
old_label = ngff_image.labels.get_label(name=label)
102-
label_level_paths = ngff_image.labels.levels_paths(name=label)
103-
# Compute the chunksize tuple
104-
chunks = old_label.on_disk_dask_array.chunks
105-
new_chunksize = [c[0] for c in chunks]
106-
# Overwrite chunk_size with user-set chunksize
107-
for i, axis in enumerate(old_label.dataset.on_disk_axes_names):
108-
if axis in chunk_sizes:
109-
if chunk_sizes[axis] is not None:
110-
new_chunksize[i] = chunk_sizes[axis]
111-
create_empty_ome_zarr_label(
112-
store=new_ngff_image.store
113-
+ "/"
114-
+ "labels"
115-
+ "/"
116-
+ label, # FIXME: Set this better?
117-
on_disk_shape=old_label.on_disk_shape,
118-
chunks=new_chunksize,
119-
dtype=old_label.on_disk_dask_array.dtype,
120-
on_disk_axis=old_label.dataset.on_disk_axes_names,
121-
pixel_sizes=old_label.dataset.pixel_size,
122-
xy_scaling_factor=old_label.metadata.xy_scaling_factor,
123-
z_scaling_factor=old_label.metadata.z_scaling_factor,
124-
time_spacing=old_label.dataset.time_spacing,
125-
time_units=old_label.dataset.time_axis_unit,
126-
levels=label_level_paths,
127-
name=label,
128-
overwrite=overwrite,
129-
version=old_label.metadata.version,
130-
)
131-
132-
# Fill in labels .attrs to contain the label name
133-
list_of_labels = new_ngff_image.labels.list()
134-
if label not in list_of_labels:
135-
new_ngff_image.labels._label_group.attrs["labels"] = [
136-
*list_of_labels,
137-
label,
138-
]
139-
140-
if rebuild_pyramids:
141-
# Set the highest resolution, then consolidate to build a new pyramid
142-
new_ngff_image.labels.get_label(
143-
name=label, highest_resolution=True
144-
).set_array(
145-
ngff_image.labels.get_label(
146-
name=label, highest_resolution=True
147-
).on_disk_dask_array
106+
# Copy labels
107+
if rechunk_labels:
108+
chunk_sizes["c"] = None
109+
label_names = ngff_image.labels.list()
110+
for label in label_names:
111+
rechunk_label(
112+
orig_ngff_image=ngff_image,
113+
new_ngff_image=new_ngff_image,
114+
label=label,
115+
chunk_sizes=chunk_sizes,
116+
overwrite=overwrite,
117+
rebuild_pyramids=rebuild_pyramids,
148118
)
149-
new_ngff_image.labels.get_label(
150-
name=label, highest_resolution=True
151-
).consolidate()
152-
else:
153-
for label_path in label_level_paths:
154-
new_ngff_image.labels.get_label(name=label, path=label_path).set_array(
155-
ngff_image.labels.get_label(
156-
name=label, path=label_path
157-
).on_disk_dask_array
158-
)
119+
159120
if overwrite_input:
160121
os.rename(zarr_url, f"{zarr_url}_tmp")
161122
os.rename(rechunked_zarr_url, zarr_url)
162123
shutil.rmtree(f"{zarr_url}_tmp")
163124
return
164125
else:
165-
image_list_updates = dict(
166-
image_list_updates=[dict(zarr_url=rechunked_zarr_url, origin=zarr_url)]
126+
output = dict(
127+
image_list_updates=[
128+
dict(
129+
zarr_url=rechunked_zarr_url,
130+
origin=zarr_url,
131+
types=dict(rechunked=True),
132+
)
133+
],
134+
filters=dict(types=dict(rechunked=True)),
167135
)
168-
return image_list_updates
136+
return output
169137

170138

171139
if __name__ == "__main__":

src/fractal_helper_tasks/utils.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2025 (C) BioVisionCenter, University of Zurich
2+
#
3+
# Original authors:
4+
# Joel Lüthi <[email protected]>
5+
"""Utils for helper tasks."""
6+
7+
from typing import Optional
8+
9+
import ngio
10+
from ngio.core.utils import create_empty_ome_zarr_label
11+
12+
13+
def normalize_chunk_size_dict(chunk_sizes: dict[str, Optional[int]]):
14+
"""Converts all chunk_size axes names to lower case and assert validity.
15+
16+
Args:
17+
chunk_sizes: Dictionary of chunk sizes that should be adapted. Can
18+
contain new chunk sizes for t, c, z, y & x.
19+
20+
Returns:
21+
chunk_sizes_norm: Normalized chunk_sizes dict.
22+
"""
23+
chunk_sizes = chunk_sizes or {}
24+
chunk_sizes_norm = {}
25+
for key, value in chunk_sizes.items():
26+
chunk_sizes_norm[key.lower()] = value
27+
28+
valid_axes = ["t", "c", "z", "y", "x"]
29+
for axis in chunk_sizes_norm:
30+
if axis not in valid_axes:
31+
raise ValueError(
32+
f"Axis {axis} is not supported. Valid axes choices are "
33+
f"{valid_axes}."
34+
)
35+
return chunk_sizes_norm
36+
37+
38+
def rechunk_label(
39+
orig_ngff_image: ngio.NgffImage,
40+
new_ngff_image: ngio.NgffImage,
41+
label: str,
42+
chunk_sizes: list[int],
43+
overwrite: bool = False,
44+
rebuild_pyramids: bool = True,
45+
):
46+
"""Saves a rechunked label image into a new OME-Zarr
47+
48+
The label image is based on an existing label image in another OME-Zarr.
49+
50+
Args:
51+
orig_ngff_image: Original OME-Zarr that contains the label image
52+
new_ngff_image: OME-Zarr to which the rechunked label image should be
53+
added.
54+
label: Name of the label image.
55+
chunk_sizes: New chunk sizes that should be applied
56+
overwrite: Whether the label image in `new_ngff_image` should be
57+
overwritten if it already exists.
58+
rebuild_pyramids: Whether pyramids are built fresh in the rechunked
59+
label image. This has a small performance overhead, but ensures
60+
that this task is save against off-by-one issues when pyramid
61+
levels aren't easily downsampled by 2.
62+
"""
63+
old_label = orig_ngff_image.labels.get_label(name=label)
64+
label_level_paths = orig_ngff_image.labels.levels_paths(name=label)
65+
# Compute the chunksize tuple
66+
chunks = old_label.on_disk_dask_array.chunks
67+
new_chunksize = [c[0] for c in chunks]
68+
# Overwrite chunk_size with user-set chunksize
69+
for i, axis in enumerate(old_label.dataset.on_disk_axes_names):
70+
if axis in chunk_sizes:
71+
if chunk_sizes[axis] is not None:
72+
new_chunksize[i] = chunk_sizes[axis]
73+
create_empty_ome_zarr_label(
74+
store=new_ngff_image.store
75+
+ "/"
76+
+ "labels"
77+
+ "/"
78+
+ label, # FIXME: Set this better?
79+
on_disk_shape=old_label.on_disk_shape,
80+
chunks=new_chunksize,
81+
dtype=old_label.on_disk_dask_array.dtype,
82+
on_disk_axis=old_label.dataset.on_disk_axes_names,
83+
pixel_sizes=old_label.dataset.pixel_size,
84+
xy_scaling_factor=old_label.metadata.xy_scaling_factor,
85+
z_scaling_factor=old_label.metadata.z_scaling_factor,
86+
time_spacing=old_label.dataset.time_spacing,
87+
time_units=old_label.dataset.time_axis_unit,
88+
levels=label_level_paths,
89+
name=label,
90+
overwrite=overwrite,
91+
version=old_label.metadata.version,
92+
)
93+
94+
# Fill in labels .attrs to contain the label name
95+
list_of_labels = new_ngff_image.labels.list()
96+
if label not in list_of_labels:
97+
new_ngff_image.labels._label_group.attrs["labels"] = [
98+
*list_of_labels,
99+
label,
100+
]
101+
102+
if rebuild_pyramids:
103+
# Set the highest resolution, then consolidate to build a new pyramid
104+
new_ngff_image.labels.get_label(name=label, highest_resolution=True).set_array(
105+
orig_ngff_image.labels.get_label(
106+
name=label, highest_resolution=True
107+
).on_disk_dask_array
108+
)
109+
new_ngff_image.labels.get_label(
110+
name=label, highest_resolution=True
111+
).consolidate()
112+
else:
113+
for label_path in label_level_paths:
114+
new_ngff_image.labels.get_label(name=label, path=label_path).set_array(
115+
orig_ngff_image.labels.get_label(
116+
name=label, path=label_path
117+
).on_disk_dask_array
118+
)

0 commit comments

Comments
 (0)