Skip to content

Commit fcb18c0

Browse files
committed
Refactor convert_2d_segmentation_to_3d task & update to ngio 0.2.2
1 parent 5afe4fb commit fcb18c0

File tree

2 files changed

+203
-167
lines changed

2 files changed

+203
-167
lines changed

src/fractal_helper_tasks/convert_2D_segmentation_to_3D.py

Lines changed: 92 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -3,97 +3,20 @@
33
import logging
44
from typing import Optional
55

6-
import anndata as ad
76
import dask.array as da
8-
import numpy as np
9-
import zarr
10-
from fractal_tasks_core.labels import prepare_label_group
11-
from fractal_tasks_core.ngff.zarr_utils import load_NgffImageMeta
12-
from fractal_tasks_core.pyramids import build_pyramid
13-
from fractal_tasks_core.tables import write_table
7+
import ngio
8+
from ngio.utils import NgioFileNotFoundError
149
from pydantic import validate_call
1510

1611
logger = logging.getLogger(__name__)
1712

1813

19-
def read_table_and_attrs(zarr_url: str, roi_table):
20-
"""Read table & attrs from Zarr Anndata tables."""
21-
table_url = f"{zarr_url}/tables/{roi_table}"
22-
table = ad.read_zarr(table_url)
23-
table_attrs = get_zattrs(table_url)
24-
return table, table_attrs
25-
26-
27-
def update_table_metadata(group_tables, table_name):
28-
"""Update table metadata."""
29-
if "tables" not in group_tables.attrs:
30-
group_tables.attrs["tables"] = [table_name]
31-
elif table_name not in group_tables.attrs["tables"]:
32-
group_tables.attrs["tables"] = group_tables.attrs["tables"] + [table_name]
33-
34-
35-
def get_zattrs(zarr_url):
36-
"""Get zattrs of a Zarr as a dictionary."""
37-
with zarr.open(zarr_url, mode="r") as zarr_img:
38-
return zarr_img.attrs.asdict()
39-
40-
41-
def make_zattrs_3D(attrs, z_pixel_size, new_label_name):
42-
"""Creates 3D zattrs based on 2D attrs.
43-
44-
Performs the following checks:
45-
1) If the label image has 2 axes, add a Z axis and updadte the
46-
coordinateTransformations
47-
2) Change the label name that is referenced, if a new name is provided
48-
"""
49-
if len(attrs["multiscales"][0]["axes"]) == 3:
50-
pass
51-
# If we're getting a 2D image, we need to add a Z axis
52-
elif len(attrs["multiscales"][0]["axes"]) == 2:
53-
z_axis = attrs["multiscales"][0]["axes"][-1]
54-
z_axis["name"] = "z"
55-
attrs["multiscales"][0]["axes"] = [z_axis] + attrs["multiscales"][0]["axes"]
56-
for i, dataset in enumerate(attrs["multiscales"][0]["datasets"]):
57-
if len(dataset["coordinateTransformations"][0]["scale"]) == 2:
58-
attrs["multiscales"][0]["datasets"][i]["coordinateTransformations"][0][
59-
"scale"
60-
] = [z_pixel_size] + dataset["coordinateTransformations"][0]["scale"]
61-
else:
62-
raise NotImplementedError(
63-
f"A dataset with 2 axes {attrs['multiscales'][0]['axes']}"
64-
"must have coordinateTransformations with 2 scales. "
65-
"Instead, it had "
66-
f"{dataset['coordinateTransformations'][0]['scale']}"
67-
)
68-
else:
69-
raise NotImplementedError("The label image must have 2 or 3 axes")
70-
attrs["multiscales"][0]["name"] = new_label_name
71-
return attrs
72-
73-
74-
def check_table_validity(new_table_names, old_table_names):
75-
"""Validate table mapping between old & new tables."""
76-
if new_table_names and old_table_names:
77-
if len(new_table_names) != len(old_table_names):
78-
raise ValueError(
79-
"The number of new table names must match the number of old "
80-
f"table names. Instead, the task got {len(new_table_names)}"
81-
"new table names vs. {len(old_table_names)} old table names."
82-
"Check the task configuration, specifically `new_table_names`"
83-
)
84-
if len(set(new_table_names)) != len(new_table_names):
85-
raise ValueError(
86-
"The new table names must be unique. Instead, the task got "
87-
f"{new_table_names}"
88-
)
89-
90-
9114
@validate_call
9215
def convert_2D_segmentation_to_3D(
9316
zarr_url: str,
9417
label_name: str,
95-
level: int = 0,
96-
ROI_tables_to_copy: Optional[list[str]] = None,
18+
level: str = 0,
19+
tables_to_copy: Optional[list[str]] = None,
9720
new_label_name: Optional[str] = None,
9821
new_table_names: Optional[list] = None,
9922
plate_suffix: str = "_mip",
@@ -121,11 +44,11 @@ def convert_2D_segmentation_to_3D(
12144
(standard argument for Fractal tasks, managed by Fractal server).
12245
label_name: Name of the label to copy from 2D OME-Zarr to
12346
3D OME-Zarr
124-
ROI_tables_to_copy: List of ROI table names to copy from 2D OME-Zarr
47+
tables_to_copy: List of tables to copy from 2D OME-Zarr
12548
to 3D OME-Zarr
12649
new_label_name: Optionally overwriting the name of the label in
12750
the 3D OME-Zarr
128-
new_table_names: Optionally overwriting the names of the ROI tables
51+
new_table_names: Optionally overwriting the names of the tables
12952
in the 3D OME-Zarr
13053
level: Level of the 2D OME-Zarr label to copy from
13154
plate_suffix: Suffix of the 2D OME-Zarr that needs to be removed to
@@ -154,114 +77,129 @@ def convert_2D_segmentation_to_3D(
15477
# 0) Preparation
15578
if level != 0:
15679
raise NotImplementedError("Only level 0 is supported at the moment")
80+
if new_table_names:
81+
if not tables_to_copy:
82+
raise ValueError(
83+
"If new_table_names is set, tables_to_copy must also be set."
84+
)
85+
if len(new_table_names) != len(tables_to_copy):
86+
raise ValueError(
87+
"If new_table_names is set, it must have the same number of "
88+
f"entries as tables_to_copy. They were: {new_table_names=}"
89+
f"and {tables_to_copy=}"
90+
)
91+
15792
zarr_3D_url = zarr_url.replace(f"{plate_suffix}.zarr", ".zarr")
15893
# Handle changes to image name
159-
# (would get easier if projections were subgroups!)
16094
if image_suffix_2D_to_remove:
16195
zarr_3D_url = zarr_3D_url.rstrip(image_suffix_2D_to_remove)
16296
if image_suffix_3D_to_add:
16397
zarr_3D_url += image_suffix_3D_to_add
16498

165-
# FIXME: Check whether 3D Zarr actually exists
166-
16799
if new_label_name is None:
168100
new_label_name = label_name
169101
if new_table_names is None:
170-
new_table_names = ROI_tables_to_copy
102+
new_table_names = tables_to_copy
103+
104+
try:
105+
ome_zarr_container_3d = ngio.open_ome_zarr_container(zarr_3D_url)
106+
except NgioFileNotFoundError as e:
107+
raise ValueError(
108+
f"3D OME-Zarr {zarr_3D_url} not found. Please check the "
109+
f"suffix (set to {plate_suffix})."
110+
) from e
171111

172-
check_table_validity(new_table_names, ROI_tables_to_copy)
173112
logger.info(
174113
f"Copying {label_name} from {zarr_url} to {zarr_3D_url} as "
175114
f"{new_label_name}."
176115
)
177116

178-
# 1a) Load a 2D label image
179-
label_img = da.from_zarr(f"{zarr_url}/labels/{label_name}/{level}")
180-
chunks = list(label_img.chunksize)
117+
# 1) Load a 2D label image
118+
ome_zarr_container_2d = ngio.open_ome_zarr_container(zarr_url)
119+
label_img = ome_zarr_container_2d.get_label(label_name, path=str(level))
120+
121+
if not label_img.is_2d:
122+
raise ValueError(
123+
f"Label image {label_name} is not 2D. It has a shape of "
124+
f"{label_img.shape} and the axes "
125+
f"{label_img.axes_mapper.on_disk_axes_names}."
126+
)
181127

182-
# 1b) Get number z planes & Z spacing from 3D OME-Zarr file
183-
with zarr.open(zarr_3D_url, mode="rw+") as zarr_img:
184-
zarr_3D = da.from_zarr(zarr_img[0])
185-
new_z_planes = zarr_3D.shape[-3]
186-
z_chunk_3d = zarr_3D.chunksize[-3]
128+
chunks = list(label_img.chunks)
129+
label_dask = label_img.get_array(mode="dask")
187130

188-
# TODO: Improve axis detection in ngio refactor?
131+
# 2) Set up the 3D label image
132+
ref_image_3d = ome_zarr_container_3d.get_image(
133+
pixel_size=label_img.pixel_size,
134+
)
135+
136+
z_index = label_img.axes_mapper.get_index("z")
137+
y_index = label_img.axes_mapper.get_index("y")
138+
x_index = label_img.axes_mapper.get_index("x")
139+
z_index_3d_reference = ref_image_3d.axes_mapper.get_index("z")
189140
if z_chunks:
190-
chunks[-3] = z_chunks
141+
chunks[z_index] = z_chunks
191142
else:
192-
chunks[-3] = z_chunk_3d
143+
chunks[z_index] = ref_image_3d.chunks[z_index_3d_reference]
193144
chunks = tuple(chunks)
194145

195-
image_meta = load_NgffImageMeta(zarr_3D_url)
196-
z_pixel_size = image_meta.get_pixel_sizes_zyx(level=0)[0]
146+
nb_z_planes = ref_image_3d.shape[z_index_3d_reference]
197147

198-
# Prepare the output label group
199-
# Get the label_attrs correctly (removes hack below)
200-
label_attrs = get_zattrs(zarr_url=f"{zarr_url}/labels/{label_name}")
201-
label_attrs = make_zattrs_3D(label_attrs, z_pixel_size, new_label_name)
202-
output_label_group = prepare_label_group(
203-
image_group=zarr.group(zarr_3D_url),
204-
label_name=new_label_name,
205-
overwrite=overwrite,
206-
label_attrs=label_attrs,
207-
logger=logger,
208-
)
148+
shape_3d = (nb_z_planes, label_img.shape[y_index], label_img.shape[x_index])
209149

210-
logger.info(f"Helper function `prepare_label_group` returned {output_label_group=}")
150+
pixel_size = label_img.pixel_size
151+
pixel_size.z = ref_image_3d.pixel_size.z
152+
axes_names = label_img.axes_mapper.on_disk_axes_names
211153

212-
# 2) Create the 3D stack of the label image
213-
label_img_3D = da.stack([label_img.squeeze()] * new_z_planes)
154+
z_extent = nb_z_planes * pixel_size.z
214155

215-
# 3) Save changed label image to OME-Zarr
216-
label_dtype = np.uint32
217-
store = zarr.storage.FSStore(f"{zarr_3D_url}/labels/{new_label_name}/0")
218-
new_label_array = zarr.create(
219-
shape=label_img_3D.shape,
156+
new_label_container = ome_zarr_container_3d.derive_label(
157+
name=new_label_name,
158+
ref_image=ref_image_3d,
159+
shape=shape_3d,
160+
pixel_size=pixel_size,
161+
axes_names=axes_names,
220162
chunks=chunks,
221-
dtype=label_dtype,
222-
store=store,
223-
overwrite=False,
224-
dimension_separator="/",
163+
dtype=label_img.dtype,
164+
overwrite=overwrite,
225165
)
226166

227-
da.array(label_img_3D).to_zarr(
228-
url=new_label_array,
229-
)
167+
# 3) Create the 3D stack of the label image
168+
label_img_3D = da.stack([label_dask.squeeze()] * nb_z_planes)
169+
170+
# 4) Save changed label image to OME-Zarr
171+
new_label_container.set_array(label_img_3D, axes_order="zyx")
172+
230173
logger.info(f"Saved {new_label_name} to 3D Zarr at full resolution")
231-
# 4) Build pyramids for label image
232-
label_meta = load_NgffImageMeta(f"{zarr_url}/labels/{label_name}")
233-
build_pyramid(
234-
zarrurl=f"{zarr_3D_url}/labels/{new_label_name}",
235-
overwrite=overwrite,
236-
num_levels=label_meta.num_levels,
237-
coarsening_xy=label_meta.coarsening_xy,
238-
chunksize=chunks,
239-
aggregation_function=np.max,
240-
)
174+
# 5) Build pyramids for label image
175+
new_label_container.consolidate()
241176
logger.info(f"Built a pyramid for the {new_label_name} label image")
242177

243-
# 5) Copy ROI tables
244-
image_group = zarr.group(zarr_3D_url)
245-
if ROI_tables_to_copy:
246-
for i, ROI_table in enumerate(ROI_tables_to_copy):
247-
new_table_name = new_table_names[i]
248-
logger.info(f"Copying ROI table {ROI_table} as {new_table_name}")
249-
roi_an, table_attrs = read_table_and_attrs(zarr_url, ROI_table)
250-
nb_rois = len(roi_an.X)
251-
# Set the new Z values to span the whole ROI
252-
roi_an.X[:, 5] = np.array([z_pixel_size * new_z_planes] * nb_rois)
253-
254-
write_table(
255-
image_group=image_group,
256-
table_name=new_table_name,
257-
table=roi_an,
258-
overwrite=overwrite,
259-
table_attrs=table_attrs,
178+
# 6) Copy tables
179+
if tables_to_copy:
180+
for i, table_name in enumerate(tables_to_copy):
181+
if table_name not in ome_zarr_container_2d.list_tables():
182+
raise ValueError(
183+
f"Table {table_name} not found in 2D OME-Zarr {zarr_url}."
184+
)
185+
table = ome_zarr_container_2d.get_table(table_name)
186+
if table.type() == "roi_table" or table.type() == "masking_ROI_table":
187+
for roi in table.rois():
188+
roi.z_length = z_extent
189+
190+
else:
191+
# For some reason, I need to load the table explicitly before
192+
# I can write it again
193+
# FIXME: Check with Lorenzo why this is
194+
table.dataframe # noqa #B018
195+
ome_zarr_container_3d.add_table(
196+
name=new_table_names[i], table=table, overwrite=False
260197
)
261198

262199
logger.info("Finished 2D to 3D conversion")
263200

264201
# Give the 3D image as an output so that filters are applied correctly
202+
# (because manifest type filters get applied to the output image)
265203
image_list_updates = dict(
266204
image_list_updates=[
267205
dict(

0 commit comments

Comments
 (0)