Skip to content

Commit b5b6cca

Browse files
committed
Add initial rechunking task
1 parent 438bf73 commit b5b6cca

File tree

2 files changed

+178
-1
lines changed

2 files changed

+178
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ authors = [
2727
# Required Python version and dependencies
2828
requires-python = ">=3.9"
2929
dependencies = [
30-
"fractal-tasks-core==1.3.4"
30+
"fractal-tasks-core==1.3.4","ngio==0.1.4",
3131
]
3232

3333
# Optional dependencies (e.g. for `pip install -e ".[dev]"`, see
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright 2024 (C) BioVisionCenter, University of Zurich
2+
#
3+
# Original authors:
4+
# Joel Lüthi <[email protected]>
5+
"""Rechunk an existing Zarr."""
6+
7+
import logging
8+
import os
9+
import shutil
10+
from typing import Any, Optional
11+
12+
import ngio
13+
from ngio.core.utils import create_empty_ome_zarr_label
14+
from pydantic import validate_call
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
@validate_call
20+
def rechunk_zarr(
21+
*,
22+
zarr_url: str,
23+
chunk_sizes: Optional[dict[str, Optional[int]]] = None,
24+
suffix: str = "rechunked",
25+
rebuild_pyramids: bool = True,
26+
overwrite_input: bool = True,
27+
overwrite: bool = False,
28+
) -> dict[str, Any]:
29+
"""Drops singleton t dimension.
30+
31+
Args:
32+
zarr_url: Path or url to the individual OME-Zarr image to be processed.
33+
(standard argument for Fractal tasks, managed by Fractal server).
34+
chunk_sizes: Dictionary of chunk sizes to adapt. One can set any of
35+
the t, c, z, y, x axes that exist in the input image to be resized
36+
to a different chunk size. For example, {"y": 4000, "x": 4000}
37+
will set a new x & y chunking while maintaining the other chunk
38+
sizes. {"z": 10} will just change the Z chunking while keeping
39+
all other chunk sizes the same as the input.
40+
suffix: Suffix of the rechunked image.
41+
rebuild_pyramids: Whether pyramids are built fresh in the rechunked
42+
image. This has a small performance overhead, but ensures that
43+
this task is save against off-by-one issues when pyramid levels
44+
aren't easily downsampled by 2.
45+
overwrite_input: Whether the old image without rechunking should be
46+
overwritten (to avoid duplicating the data needed).
47+
overwrite: Whether to overwrite potential pre-existing output with the
48+
name zarr_url_suffix.
49+
"""
50+
chunk_sizes = chunk_sizes or {}
51+
valid_axes = ["t", "c", "z", "y", "x"]
52+
for axis in valid_axes:
53+
if axis not in chunk_sizes:
54+
chunk_sizes[axis] = None
55+
56+
rechunked_zarr_url = zarr_url + f"_{suffix}"
57+
ngff_image = ngio.NgffImage(zarr_url)
58+
pyramid_paths = ngff_image.levels_paths
59+
highest_res_img = ngff_image.get_image()
60+
axes_names = highest_res_img.dataset.on_disk_axes_names
61+
chunks = highest_res_img.on_disk_dask_array.chunks
62+
63+
# Compute the chunksize tuple
64+
new_chunksize = [c[0] for c in chunks]
65+
# Overwrite chunk_size with user-set chunksize
66+
for i, axis in enumerate(axes_names):
67+
if axis in chunk_sizes:
68+
if chunk_sizes[axis] is not None:
69+
new_chunksize[i] = chunk_sizes[axis]
70+
71+
# TODO: Check for extra axes specified
72+
73+
new_ngff_image = ngff_image.derive_new_image(
74+
store=rechunked_zarr_url,
75+
name=ngff_image.image_meta.name,
76+
overwrite=overwrite,
77+
copy_labels=False, # Copy if rechunk labels is not selected?
78+
copy_tables=True,
79+
chunks=new_chunksize,
80+
)
81+
82+
ngff_image = ngio.NgffImage(zarr_url)
83+
84+
if rebuild_pyramids:
85+
# Set the highest resolution, then consolidate to build a new pyramid
86+
new_ngff_image.get_image(highest_resolution=True).set_array(
87+
ngff_image.get_image(highest_resolution=True).on_disk_dask_array
88+
)
89+
new_ngff_image.get_image(highest_resolution=True).consolidate()
90+
else:
91+
for path in pyramid_paths:
92+
new_ngff_image.get_image(path=path).set_array(
93+
ngff_image.get_image(path=path).on_disk_dask_array
94+
)
95+
96+
# Copy labels: Loop over them
97+
# Labels don't have a channel dimension
98+
chunk_sizes["c"] = None
99+
label_names = ngff_image.labels.list()
100+
for label in label_names:
101+
old_label = ngff_image.labels.get_label(name=label)
102+
label_level_paths = ngff_image.labels.levels_paths(name=label)
103+
# Compute the chunksize tuple
104+
chunks = old_label.on_disk_dask_array.chunks
105+
new_chunksize = [c[0] for c in chunks]
106+
# Overwrite chunk_size with user-set chunksize
107+
for i, axis in enumerate(old_label.dataset.on_disk_axes_names):
108+
if axis in chunk_sizes:
109+
if chunk_sizes[axis] is not None:
110+
new_chunksize[i] = chunk_sizes[axis]
111+
create_empty_ome_zarr_label(
112+
store=new_ngff_image.store
113+
+ "/"
114+
+ "labels"
115+
+ "/"
116+
+ label, # FIXME: Set this better?
117+
on_disk_shape=old_label.on_disk_shape,
118+
chunks=new_chunksize,
119+
dtype=old_label.on_disk_dask_array.dtype,
120+
on_disk_axis=old_label.dataset.on_disk_axes_names,
121+
pixel_sizes=old_label.dataset.pixel_size,
122+
xy_scaling_factor=old_label.metadata.xy_scaling_factor,
123+
z_scaling_factor=old_label.metadata.z_scaling_factor,
124+
time_spacing=old_label.dataset.time_spacing,
125+
time_units=old_label.dataset.time_axis_unit,
126+
levels=label_level_paths,
127+
name=label,
128+
overwrite=overwrite,
129+
version=old_label.metadata.version,
130+
)
131+
132+
# Fill in labels .attrs to contain the label name
133+
list_of_labels = new_ngff_image.labels.list()
134+
if label not in list_of_labels:
135+
new_ngff_image.labels._label_group.attrs["labels"] = [
136+
*list_of_labels,
137+
label,
138+
]
139+
140+
if rebuild_pyramids:
141+
# Set the highest resolution, then consolidate to build a new pyramid
142+
new_ngff_image.labels.get_label(
143+
name=label, highest_resolution=True
144+
).set_array(
145+
ngff_image.labels.get_label(
146+
name=label, highest_resolution=True
147+
).on_disk_dask_array
148+
)
149+
new_ngff_image.labels.get_label(
150+
name=label, highest_resolution=True
151+
).consolidate()
152+
else:
153+
for label_path in label_level_paths:
154+
new_ngff_image.labels.get_label(name=label, path=label_path).set_array(
155+
ngff_image.labels.get_label(
156+
name=label, path=label_path
157+
).on_disk_dask_array
158+
)
159+
if overwrite_input:
160+
os.rename(zarr_url, f"{zarr_url}_tmp")
161+
os.rename(rechunked_zarr_url, zarr_url)
162+
shutil.rmtree(f"{zarr_url}_tmp")
163+
return
164+
else:
165+
image_list_updates = dict(
166+
image_list_updates=[dict(zarr_url=rechunked_zarr_url, origin=zarr_url)]
167+
)
168+
return image_list_updates
169+
170+
171+
if __name__ == "__main__":
172+
from fractal_tasks_core.tasks._utils import run_fractal_task
173+
174+
run_fractal_task(
175+
task_function=rechunk_zarr,
176+
logger_name=logger.name,
177+
)

0 commit comments

Comments
 (0)