|
3 | 3 | import logging |
4 | 4 | from typing import Optional |
5 | 5 |
|
6 | | -import anndata as ad |
7 | 6 | import dask.array as da |
8 | | -import numpy as np |
9 | | -import zarr |
10 | | -from fractal_tasks_core.labels import prepare_label_group |
11 | | -from fractal_tasks_core.ngff.zarr_utils import load_NgffImageMeta |
12 | | -from fractal_tasks_core.pyramids import build_pyramid |
13 | | -from fractal_tasks_core.tables import write_table |
| 7 | +import ngio |
| 8 | +from ngio.utils import NgioFileNotFoundError |
14 | 9 | from pydantic import validate_call |
15 | 10 |
|
16 | 11 | logger = logging.getLogger(__name__) |
17 | 12 |
|
18 | 13 |
|
19 | | -def read_table_and_attrs(zarr_url: str, roi_table): |
20 | | - """Read table & attrs from Zarr Anndata tables.""" |
21 | | - table_url = f"{zarr_url}/tables/{roi_table}" |
22 | | - table = ad.read_zarr(table_url) |
23 | | - table_attrs = get_zattrs(table_url) |
24 | | - return table, table_attrs |
25 | | - |
26 | | - |
27 | | -def update_table_metadata(group_tables, table_name): |
28 | | - """Update table metadata.""" |
29 | | - if "tables" not in group_tables.attrs: |
30 | | - group_tables.attrs["tables"] = [table_name] |
31 | | - elif table_name not in group_tables.attrs["tables"]: |
32 | | - group_tables.attrs["tables"] = group_tables.attrs["tables"] + [table_name] |
33 | | - |
34 | | - |
35 | | -def get_zattrs(zarr_url): |
36 | | - """Get zattrs of a Zarr as a dictionary.""" |
37 | | - with zarr.open(zarr_url, mode="r") as zarr_img: |
38 | | - return zarr_img.attrs.asdict() |
39 | | - |
40 | | - |
41 | | -def make_zattrs_3D(attrs, z_pixel_size, new_label_name): |
42 | | - """Creates 3D zattrs based on 2D attrs. |
43 | | -
|
44 | | - Performs the following checks: |
45 | | - 1) If the label image has 2 axes, add a Z axis and updadte the |
46 | | - coordinateTransformations |
47 | | - 2) Change the label name that is referenced, if a new name is provided |
48 | | - """ |
49 | | - if len(attrs["multiscales"][0]["axes"]) == 3: |
50 | | - pass |
51 | | - # If we're getting a 2D image, we need to add a Z axis |
52 | | - elif len(attrs["multiscales"][0]["axes"]) == 2: |
53 | | - z_axis = attrs["multiscales"][0]["axes"][-1] |
54 | | - z_axis["name"] = "z" |
55 | | - attrs["multiscales"][0]["axes"] = [z_axis] + attrs["multiscales"][0]["axes"] |
56 | | - for i, dataset in enumerate(attrs["multiscales"][0]["datasets"]): |
57 | | - if len(dataset["coordinateTransformations"][0]["scale"]) == 2: |
58 | | - attrs["multiscales"][0]["datasets"][i]["coordinateTransformations"][0][ |
59 | | - "scale" |
60 | | - ] = [z_pixel_size] + dataset["coordinateTransformations"][0]["scale"] |
61 | | - else: |
62 | | - raise NotImplementedError( |
63 | | - f"A dataset with 2 axes {attrs['multiscales'][0]['axes']}" |
64 | | - "must have coordinateTransformations with 2 scales. " |
65 | | - "Instead, it had " |
66 | | - f"{dataset['coordinateTransformations'][0]['scale']}" |
67 | | - ) |
68 | | - else: |
69 | | - raise NotImplementedError("The label image must have 2 or 3 axes") |
70 | | - attrs["multiscales"][0]["name"] = new_label_name |
71 | | - return attrs |
72 | | - |
73 | | - |
74 | | -def check_table_validity(new_table_names, old_table_names): |
75 | | - """Validate table mapping between old & new tables.""" |
76 | | - if new_table_names and old_table_names: |
77 | | - if len(new_table_names) != len(old_table_names): |
78 | | - raise ValueError( |
79 | | - "The number of new table names must match the number of old " |
80 | | - f"table names. Instead, the task got {len(new_table_names)}" |
81 | | - "new table names vs. {len(old_table_names)} old table names." |
82 | | - "Check the task configuration, specifically `new_table_names`" |
83 | | - ) |
84 | | - if len(set(new_table_names)) != len(new_table_names): |
85 | | - raise ValueError( |
86 | | - "The new table names must be unique. Instead, the task got " |
87 | | - f"{new_table_names}" |
88 | | - ) |
89 | | - |
90 | | - |
91 | 14 | @validate_call |
92 | 15 | def convert_2D_segmentation_to_3D( |
93 | 16 | zarr_url: str, |
94 | 17 | label_name: str, |
95 | | - level: int = 0, |
96 | | - ROI_tables_to_copy: Optional[list[str]] = None, |
| 18 | + level: str = 0, |
| 19 | + tables_to_copy: Optional[list[str]] = None, |
97 | 20 | new_label_name: Optional[str] = None, |
98 | 21 | new_table_names: Optional[list] = None, |
99 | 22 | plate_suffix: str = "_mip", |
@@ -121,11 +44,11 @@ def convert_2D_segmentation_to_3D( |
121 | 44 | (standard argument for Fractal tasks, managed by Fractal server). |
122 | 45 | label_name: Name of the label to copy from 2D OME-Zarr to |
123 | 46 | 3D OME-Zarr |
124 | | - ROI_tables_to_copy: List of ROI table names to copy from 2D OME-Zarr |
| 47 | + tables_to_copy: List of tables to copy from 2D OME-Zarr |
125 | 48 | to 3D OME-Zarr |
126 | 49 | new_label_name: Optionally overwriting the name of the label in |
127 | 50 | the 3D OME-Zarr |
128 | | - new_table_names: Optionally overwriting the names of the ROI tables |
| 51 | + new_table_names: Optionally overwriting the names of the tables |
129 | 52 | in the 3D OME-Zarr |
130 | 53 | level: Level of the 2D OME-Zarr label to copy from |
131 | 54 | plate_suffix: Suffix of the 2D OME-Zarr that needs to be removed to |
@@ -154,114 +77,129 @@ def convert_2D_segmentation_to_3D( |
154 | 77 | # 0) Preparation |
155 | 78 | if level != 0: |
156 | 79 | raise NotImplementedError("Only level 0 is supported at the moment") |
| 80 | + if new_table_names: |
| 81 | + if not tables_to_copy: |
| 82 | + raise ValueError( |
| 83 | + "If new_table_names is set, tables_to_copy must also be set." |
| 84 | + ) |
| 85 | + if len(new_table_names) != len(tables_to_copy): |
| 86 | + raise ValueError( |
| 87 | + "If new_table_names is set, it must have the same number of " |
| 88 | + f"entries as tables_to_copy. They were: {new_table_names=}" |
| 89 | + f"and {tables_to_copy=}" |
| 90 | + ) |
| 91 | + |
157 | 92 | zarr_3D_url = zarr_url.replace(f"{plate_suffix}.zarr", ".zarr") |
158 | 93 | # Handle changes to image name |
159 | | - # (would get easier if projections were subgroups!) |
160 | 94 | if image_suffix_2D_to_remove: |
161 | 95 | zarr_3D_url = zarr_3D_url.rstrip(image_suffix_2D_to_remove) |
162 | 96 | if image_suffix_3D_to_add: |
163 | 97 | zarr_3D_url += image_suffix_3D_to_add |
164 | 98 |
|
165 | | - # FIXME: Check whether 3D Zarr actually exists |
166 | | - |
167 | 99 | if new_label_name is None: |
168 | 100 | new_label_name = label_name |
169 | 101 | if new_table_names is None: |
170 | | - new_table_names = ROI_tables_to_copy |
| 102 | + new_table_names = tables_to_copy |
| 103 | + |
| 104 | + try: |
| 105 | + ome_zarr_container_3d = ngio.open_ome_zarr_container(zarr_3D_url) |
| 106 | + except NgioFileNotFoundError as e: |
| 107 | + raise ValueError( |
| 108 | + f"3D OME-Zarr {zarr_3D_url} not found. Please check the " |
| 109 | + f"suffix (set to {plate_suffix})." |
| 110 | + ) from e |
171 | 111 |
|
172 | | - check_table_validity(new_table_names, ROI_tables_to_copy) |
173 | 112 | logger.info( |
174 | 113 | f"Copying {label_name} from {zarr_url} to {zarr_3D_url} as " |
175 | 114 | f"{new_label_name}." |
176 | 115 | ) |
177 | 116 |
|
178 | | - # 1a) Load a 2D label image |
179 | | - label_img = da.from_zarr(f"{zarr_url}/labels/{label_name}/{level}") |
180 | | - chunks = list(label_img.chunksize) |
| 117 | + # 1) Load a 2D label image |
| 118 | + ome_zarr_container_2d = ngio.open_ome_zarr_container(zarr_url) |
| 119 | + label_img = ome_zarr_container_2d.get_label(label_name, path=str(level)) |
| 120 | + |
| 121 | + if not label_img.is_2d: |
| 122 | + raise ValueError( |
| 123 | + f"Label image {label_name} is not 2D. It has a shape of " |
| 124 | + f"{label_img.shape} and the axes " |
| 125 | + f"{label_img.axes_mapper.on_disk_axes_names}." |
| 126 | + ) |
181 | 127 |
|
182 | | - # 1b) Get number z planes & Z spacing from 3D OME-Zarr file |
183 | | - with zarr.open(zarr_3D_url, mode="rw+") as zarr_img: |
184 | | - zarr_3D = da.from_zarr(zarr_img[0]) |
185 | | - new_z_planes = zarr_3D.shape[-3] |
186 | | - z_chunk_3d = zarr_3D.chunksize[-3] |
| 128 | + chunks = list(label_img.chunks) |
| 129 | + label_dask = label_img.get_array(mode="dask") |
187 | 130 |
|
188 | | - # TODO: Improve axis detection in ngio refactor? |
| 131 | + # 2) Set up the 3D label image |
| 132 | + ref_image_3d = ome_zarr_container_3d.get_image( |
| 133 | + pixel_size=label_img.pixel_size, |
| 134 | + ) |
| 135 | + |
| 136 | + z_index = label_img.axes_mapper.get_index("z") |
| 137 | + y_index = label_img.axes_mapper.get_index("y") |
| 138 | + x_index = label_img.axes_mapper.get_index("x") |
| 139 | + z_index_3d_reference = ref_image_3d.axes_mapper.get_index("z") |
189 | 140 | if z_chunks: |
190 | | - chunks[-3] = z_chunks |
| 141 | + chunks[z_index] = z_chunks |
191 | 142 | else: |
192 | | - chunks[-3] = z_chunk_3d |
| 143 | + chunks[z_index] = ref_image_3d.chunks[z_index_3d_reference] |
193 | 144 | chunks = tuple(chunks) |
194 | 145 |
|
195 | | - image_meta = load_NgffImageMeta(zarr_3D_url) |
196 | | - z_pixel_size = image_meta.get_pixel_sizes_zyx(level=0)[0] |
| 146 | + nb_z_planes = ref_image_3d.shape[z_index_3d_reference] |
197 | 147 |
|
198 | | - # Prepare the output label group |
199 | | - # Get the label_attrs correctly (removes hack below) |
200 | | - label_attrs = get_zattrs(zarr_url=f"{zarr_url}/labels/{label_name}") |
201 | | - label_attrs = make_zattrs_3D(label_attrs, z_pixel_size, new_label_name) |
202 | | - output_label_group = prepare_label_group( |
203 | | - image_group=zarr.group(zarr_3D_url), |
204 | | - label_name=new_label_name, |
205 | | - overwrite=overwrite, |
206 | | - label_attrs=label_attrs, |
207 | | - logger=logger, |
208 | | - ) |
| 148 | + shape_3d = (nb_z_planes, label_img.shape[y_index], label_img.shape[x_index]) |
209 | 149 |
|
210 | | - logger.info(f"Helper function `prepare_label_group` returned {output_label_group=}") |
| 150 | + pixel_size = label_img.pixel_size |
| 151 | + pixel_size.z = ref_image_3d.pixel_size.z |
| 152 | + axes_names = label_img.axes_mapper.on_disk_axes_names |
211 | 153 |
|
212 | | - # 2) Create the 3D stack of the label image |
213 | | - label_img_3D = da.stack([label_img.squeeze()] * new_z_planes) |
| 154 | + z_extent = nb_z_planes * pixel_size.z |
214 | 155 |
|
215 | | - # 3) Save changed label image to OME-Zarr |
216 | | - label_dtype = np.uint32 |
217 | | - store = zarr.storage.FSStore(f"{zarr_3D_url}/labels/{new_label_name}/0") |
218 | | - new_label_array = zarr.create( |
219 | | - shape=label_img_3D.shape, |
| 156 | + new_label_container = ome_zarr_container_3d.derive_label( |
| 157 | + name=new_label_name, |
| 158 | + ref_image=ref_image_3d, |
| 159 | + shape=shape_3d, |
| 160 | + pixel_size=pixel_size, |
| 161 | + axes_names=axes_names, |
220 | 162 | chunks=chunks, |
221 | | - dtype=label_dtype, |
222 | | - store=store, |
223 | | - overwrite=False, |
224 | | - dimension_separator="/", |
| 163 | + dtype=label_img.dtype, |
| 164 | + overwrite=overwrite, |
225 | 165 | ) |
226 | 166 |
|
227 | | - da.array(label_img_3D).to_zarr( |
228 | | - url=new_label_array, |
229 | | - ) |
| 167 | + # 3) Create the 3D stack of the label image |
| 168 | + label_img_3D = da.stack([label_dask.squeeze()] * nb_z_planes) |
| 169 | + |
| 170 | + # 4) Save changed label image to OME-Zarr |
| 171 | + new_label_container.set_array(label_img_3D, axes_order="zyx") |
| 172 | + |
230 | 173 | logger.info(f"Saved {new_label_name} to 3D Zarr at full resolution") |
231 | | - # 4) Build pyramids for label image |
232 | | - label_meta = load_NgffImageMeta(f"{zarr_url}/labels/{label_name}") |
233 | | - build_pyramid( |
234 | | - zarrurl=f"{zarr_3D_url}/labels/{new_label_name}", |
235 | | - overwrite=overwrite, |
236 | | - num_levels=label_meta.num_levels, |
237 | | - coarsening_xy=label_meta.coarsening_xy, |
238 | | - chunksize=chunks, |
239 | | - aggregation_function=np.max, |
240 | | - ) |
| 174 | + # 5) Build pyramids for label image |
| 175 | + new_label_container.consolidate() |
241 | 176 | logger.info(f"Built a pyramid for the {new_label_name} label image") |
242 | 177 |
|
243 | | - # 5) Copy ROI tables |
244 | | - image_group = zarr.group(zarr_3D_url) |
245 | | - if ROI_tables_to_copy: |
246 | | - for i, ROI_table in enumerate(ROI_tables_to_copy): |
247 | | - new_table_name = new_table_names[i] |
248 | | - logger.info(f"Copying ROI table {ROI_table} as {new_table_name}") |
249 | | - roi_an, table_attrs = read_table_and_attrs(zarr_url, ROI_table) |
250 | | - nb_rois = len(roi_an.X) |
251 | | - # Set the new Z values to span the whole ROI |
252 | | - roi_an.X[:, 5] = np.array([z_pixel_size * new_z_planes] * nb_rois) |
253 | | - |
254 | | - write_table( |
255 | | - image_group=image_group, |
256 | | - table_name=new_table_name, |
257 | | - table=roi_an, |
258 | | - overwrite=overwrite, |
259 | | - table_attrs=table_attrs, |
| 178 | + # 6) Copy tables |
| 179 | + if tables_to_copy: |
| 180 | + for i, table_name in enumerate(tables_to_copy): |
| 181 | + if table_name not in ome_zarr_container_2d.list_tables(): |
| 182 | + raise ValueError( |
| 183 | + f"Table {table_name} not found in 2D OME-Zarr {zarr_url}." |
| 184 | + ) |
| 185 | + table = ome_zarr_container_2d.get_table(table_name) |
| 186 | + if table.type() == "roi_table" or table.type() == "masking_ROI_table": |
| 187 | + for roi in table.rois(): |
| 188 | + roi.z_length = z_extent |
| 189 | + |
| 190 | + else: |
| 191 | + # For some reason, I need to load the table explicitly before |
| 192 | + # I can write it again |
| 193 | + # FIXME: Check with Lorenzo why this is |
| 194 | + table.dataframe # noqa #B018 |
| 195 | + ome_zarr_container_3d.add_table( |
| 196 | + name=new_table_names[i], table=table, overwrite=False |
260 | 197 | ) |
261 | 198 |
|
262 | 199 | logger.info("Finished 2D to 3D conversion") |
263 | 200 |
|
264 | 201 | # Give the 3D image as an output so that filters are applied correctly |
| 202 | + # (because manifest type filters get applied to the output image) |
265 | 203 | image_list_updates = dict( |
266 | 204 | image_list_updates=[ |
267 | 205 | dict( |
|
0 commit comments