|
10 | 10 | from typing import Any, Optional |
11 | 11 |
|
12 | 12 | import ngio |
| 13 | +from ngio.ome_zarr_meta import AxesMapper |
13 | 14 | from pydantic import validate_call |
14 | 15 |
|
15 | | -from fractal_helper_tasks.utils import normalize_chunk_size_dict, rechunk_label |
| 16 | +from fractal_helper_tasks.utils import normalize_chunk_size_dict |
16 | 17 |
|
17 | 18 | logger = logging.getLogger(__name__) |
18 | 19 |
|
19 | 20 |
|
def change_chunks(
    initial_chunks: list[int],
    axes_mapper: AxesMapper,
    chunk_sizes: dict[str, Optional[int]],
) -> list[int]:
    """Create a new chunk-size list with the requested rechunking applied.

    Starts from ``initial_chunks`` (one entry per on-disk axis) and
    overrides the entries for the axes named in ``chunk_sizes``. Entries
    mapped to ``None`` are left unchanged and are not validated, which
    lets callers reuse one ``chunk_sizes`` dict for arrays that lack
    some axes (e.g. ``{"c": None}`` for label images without a channel
    axis).

    Args:
        initial_chunks: Current chunk sizes, in on-disk axis order.
        axes_mapper: Axes mapper of the OME-Zarr array, used to resolve
            axis names to on-disk indices.
        chunk_sizes: Mapping from axis name to the new chunk size, or
            ``None`` to keep the existing chunk size for that axis.

    Returns:
        A new list of chunk sizes; ``initial_chunks`` is not modified.

    Raises:
        ValueError: If a non-``None`` chunk size is requested for an
            axis that does not exist in the OME-Zarr array.
    """
    # Work on a copy so the caller's list is never mutated in place
    # (the docstring promises a *new* chunk-size list).
    new_chunks = list(initial_chunks)
    for axes_name, chunk_value in chunk_sizes.items():
        if chunk_value is None:
            # None means "keep current chunking"; skip validation too, so
            # axes absent from this array (e.g. "c" on labels) are allowed.
            continue
        axes_index = axes_mapper.get_index(axes_name)
        if axes_index is None:
            raise ValueError(
                f"Rechunking with {axes_name=} is specified, but the "
                "OME-Zarr only has the following axes: "
                f"{axes_mapper.on_disk_axes_names}"
            )
        new_chunks[axes_index] = chunk_value
    return new_chunks
| 44 | + |
20 | 45 | @validate_call |
21 | 46 | def rechunk_zarr( |
22 | 47 | *, |
@@ -56,73 +81,75 @@ def rechunk_zarr( |
56 | 81 | chunk_sizes = normalize_chunk_size_dict(chunk_sizes) |
57 | 82 |
|
58 | 83 | rechunked_zarr_url = zarr_url + f"_{suffix}" |
59 | | - ngff_image = ngio.NgffImage(zarr_url) |
60 | | - pyramid_paths = ngff_image.levels_paths |
61 | | - highest_res_img = ngff_image.get_image() |
62 | | - axes_names = highest_res_img.dataset.on_disk_axes_names |
63 | | - chunks = highest_res_img.on_disk_dask_array.chunks |
64 | | - |
65 | | - # Compute the chunksize tuple |
66 | | - new_chunksize = [c[0] for c in chunks] |
67 | | - logger.info(f"Initial chunk sizes were: {chunks}") |
68 | | - # Overwrite chunk_size with user-set chunksize |
69 | | - for i, axis in enumerate(axes_names): |
70 | | - if axis in chunk_sizes: |
71 | | - if chunk_sizes[axis] is not None: |
72 | | - new_chunksize[i] = chunk_sizes[axis] |
73 | | - |
74 | | - for axis in chunk_sizes: |
75 | | - if axis not in axes_names: |
76 | | - raise NotImplementedError( |
77 | | - f"Rechunking with {axis=} is specified, but the OME-Zarr only " |
78 | | - f"has the following axes: {axes_names}" |
79 | | - ) |
| 84 | + ome_zarr_container = ngio.open_ome_zarr_container(zarr_url) |
| 85 | + pyramid_paths = ome_zarr_container.levels_paths |
| 86 | + highest_res_img = ome_zarr_container.get_image() |
| 87 | + chunks = highest_res_img.chunks |
| 88 | + new_chunksize = change_chunks( |
| 89 | + initial_chunks=list(chunks), |
| 90 | + axes_mapper=highest_res_img.meta.axes_mapper, |
| 91 | + chunk_sizes=chunk_sizes, |
| 92 | + ) |
80 | 93 |
|
81 | 94 | logger.info(f"Chunk sizes after rechunking will be: {new_chunksize=}") |
82 | 95 |
|
83 | | - new_ngff_image = ngff_image.derive_new_image( |
| 96 | + new_ome_zarr_container = ome_zarr_container.derive_image( |
84 | 97 | store=rechunked_zarr_url, |
85 | | - name=ngff_image.image_meta.name, |
| 98 | + name=ome_zarr_container.image_meta.name, |
86 | 99 | overwrite=overwrite, |
87 | 100 | copy_labels=not rechunk_labels, |
88 | 101 | copy_tables=True, |
89 | 102 | chunks=new_chunksize, |
90 | 103 | ) |
91 | 104 |
|
92 | | - ngff_image = ngio.NgffImage(zarr_url) |
93 | | - |
94 | 105 | if rebuild_pyramids: |
95 | 106 | # Set the highest resolution, then consolidate to build a new pyramid |
96 | | - new_ngff_image.get_image(highest_resolution=True).set_array( |
97 | | - ngff_image.get_image(highest_resolution=True).on_disk_dask_array |
98 | | - ) |
99 | | - new_ngff_image.get_image(highest_resolution=True).consolidate() |
| 107 | + new_image = new_ome_zarr_container.get_image() |
| 108 | + new_image.set_array(ome_zarr_container.get_image().get_array(mode="dask")) |
| 109 | + new_image.consolidate() |
100 | 110 | else: |
101 | 111 | for path in pyramid_paths: |
102 | | - new_ngff_image.get_image(path=path).set_array( |
103 | | - ngff_image.get_image(path=path).on_disk_dask_array |
| 112 | + new_ome_zarr_container.get_image(path=path).set_array( |
| 113 | + ome_zarr_container.get_image(path=path).get_array(mode="dask") |
104 | 114 | ) |
105 | 115 |
|
106 | | - # Copy labels |
| 116 | + # Rechunk labels |
107 | 117 | if rechunk_labels: |
108 | 118 | chunk_sizes["c"] = None |
109 | | - label_names = ngff_image.labels.list() |
| 119 | + label_names = ome_zarr_container.list_labels() |
110 | 120 | for label in label_names: |
111 | | - rechunk_label( |
112 | | - orig_ngff_image=ngff_image, |
113 | | - new_ngff_image=new_ngff_image, |
114 | | - label=label, |
| 121 | + old_label = ome_zarr_container.get_label(name=label) |
| 122 | + new_chunksize = change_chunks( |
| 123 | + initial_chunks=list(old_label.chunks), |
| 124 | + axes_mapper=old_label.meta.axes_mapper, |
115 | 125 | chunk_sizes=chunk_sizes, |
| 126 | + ) |
| 127 | + ngio.images.label._derive_label( |
| 128 | + name=label, |
| 129 | + store=f"{rechunked_zarr_url}/labels/{label}", |
| 130 | + ref_image=old_label, |
| 131 | + chunks=new_chunksize, |
116 | 132 | overwrite=overwrite, |
117 | | - rebuild_pyramids=rebuild_pyramids, |
118 | 133 | ) |
| 134 | + if rebuild_pyramids: |
| 135 | + new_label = new_ome_zarr_container.get_label(name=label) |
| 136 | + new_label.set_array(old_label.get_array(mode="dask")) |
| 137 | + new_label.consolidate() |
| 138 | + else: |
| 139 | + label_pyramid_paths = old_label.meta.paths |
| 140 | + for path in label_pyramid_paths: |
| 141 | + new_ome_zarr_container.get_label(name=label, path=path).set_array( |
| 142 | + old_label.get_array(path=path, mode="dask") |
| 143 | + ) |
119 | 144 |
|
120 | 145 | if overwrite_input: |
121 | 146 | os.rename(zarr_url, f"{zarr_url}_tmp") |
122 | 147 | os.rename(rechunked_zarr_url, zarr_url) |
123 | 148 | shutil.rmtree(f"{zarr_url}_tmp") |
124 | 149 | return |
125 | 150 | else: |
| 151 | + # FIXME: Update well metadata to add the new image if the image is in |
| 152 | + # a well |
126 | 153 | output = dict( |
127 | 154 | image_list_updates=[ |
128 | 155 | dict( |
|
0 commit comments