Skip to content

Commit 605d469

Browse files
ngio based re-implementation of projection task
1 parent e55821e commit 605d469

File tree

2 files changed

+78
-126
lines changed

2 files changed

+78
-126
lines changed

fractal_tasks_core/tasks/projection.py

Lines changed: 76 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -12,28 +12,36 @@
1212
"""
1313
Task for 3D->2D maximum-intensity projection.
1414
"""
15-
import logging
16-
from typing import Any
15+
from __future__ import annotations
16+
from typing import Any, TYPE_CHECKING
1717

18-
import anndata as ad
1918
import dask.array as da
20-
import zarr
19+
from ngio import NgffImage
20+
2121
from pydantic import validate_call
22-
from zarr.errors import ContainsArrayError
23-
24-
from fractal_tasks_core.ngff import load_NgffImageMeta
25-
from fractal_tasks_core.pyramids import build_pyramid
26-
from fractal_tasks_core.roi import (
27-
convert_ROIs_from_3D_to_2D,
28-
)
29-
from fractal_tasks_core.tables import write_table
30-
from fractal_tasks_core.tables.v1 import get_tables_list_v1
22+
3123
from fractal_tasks_core.tasks.io_models import InitArgsMIP
3224
from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod
33-
from fractal_tasks_core.zarr_utils import OverwriteNotAllowedError
25+
from ngio.utils import ngio_logger
26+
27+
if TYPE_CHECKING:
28+
from ngio.core import Image
3429

3530

36-
logger = logging.getLogger(__name__)
31+
def _compute_new_shape(source_image: Image) -> tuple[int, ...]:
    """Compute the on-disk shape of the image after the projection.

    The new shape is the same as the source one, except for the z-axis,
    whose length is set to 1.

    Args:
        source_image: Image to be projected. Only its `on_disk_shape` and
            `dataset.on_disk_axes_names` attributes are read.

    Returns:
        The destination on-disk shape, with one entry per on-disk axis.

    Raises:
        ValueError: Propagated from `.index("z")` when no "z" axis is
            present (assumes `on_disk_axes_names` is a list/tuple of
            axis names -- TODO confirm against ngio).
    """
    on_disk_shape = source_image.on_disk_shape
    ngio_logger.info(f"Source {on_disk_shape=}")

    # Position of the z-axis in the on-disk axes order.
    on_disk_z_index = source_image.dataset.on_disk_axes_names.index("z")

    # Work on a mutable copy so the source shape is never modified in place.
    dest_on_disk_shape = list(on_disk_shape)
    dest_on_disk_shape[on_disk_z_index] = 1
    ngio_logger.info(f"Destination {dest_on_disk_shape=}")
    return tuple(dest_on_disk_shape)
3745

3846

3947
@validate_call
@@ -55,121 +63,66 @@ def projection(
5563
`create_cellvoyager_ome_zarr_init`.
5664
"""
5765
method = DaskProjectionMethod(init_args.method)
58-
logger.info(f"{init_args.origin_url=}")
59-
logger.info(f"{zarr_url=}")
60-
logger.info(f"{method=}")
66+
ngio_logger.info(f"{init_args.origin_url=}")
67+
ngio_logger.info(f"{zarr_url=}")
68+
ngio_logger.info(f"{method=}")
6169

6270
# Read image metadata
63-
ngff_image = load_NgffImageMeta(init_args.origin_url)
64-
# Currently not using the validation models due to wavelength_id issue
65-
# See #681 for discussion
66-
# new_attrs = ngff_image.model_dump(exclude_none=True)
67-
# Current way to get the necessary metadata for MIP
68-
group = zarr.open_group(init_args.origin_url, mode="r")
69-
new_attrs = group.attrs.asdict()
70-
71-
# Create the zarr image with correct
72-
new_image_group = zarr.group(zarr_url)
73-
new_image_group.attrs.put(new_attrs)
74-
75-
# Load 0-th level
76-
data_czyx = da.from_zarr(init_args.origin_url + "/0")
77-
num_channels = data_czyx.shape[0]
78-
chunksize_y = data_czyx.chunksize[-2]
79-
chunksize_x = data_czyx.chunksize[-1]
80-
logger.info(f"{num_channels=}")
81-
logger.info(f"{chunksize_y=}")
82-
logger.info(f"{chunksize_x=}")
83-
84-
# Loop over channels
85-
accumulate_chl = []
86-
for ind_ch in range(num_channels):
87-
# Perform MIP for each channel of level 0
88-
project_yx = da.stack(
89-
[method.apply(data_czyx[ind_ch], axis=0)], axis=0
90-
)
91-
accumulate_chl.append(project_yx)
92-
accumulated_array = da.stack(accumulate_chl, axis=0)
93-
94-
# Write to disk (triggering execution)
95-
try:
96-
accumulated_array.to_zarr(
97-
f"{zarr_url}/0",
98-
overwrite=init_args.overwrite,
99-
dimension_separator="/",
100-
write_empty_chunks=False,
101-
)
102-
except ContainsArrayError as e:
103-
error_msg = (
104-
f"Cannot write array to zarr group at '{zarr_url}/0', "
105-
f"with {init_args.overwrite=} (original error: {str(e)}).\n"
106-
"Hint: try setting overwrite=True."
71+
original_ngff_image = NgffImage(init_args.origin_url)
72+
orginal_image = original_ngff_image.get_image()
73+
74+
if orginal_image.is_2d or orginal_image.is_2d_time_series:
75+
raise ValueError(
76+
"The input image is 2D, "
77+
"projection is only supported for 3D images."
10778
)
108-
logger.error(error_msg)
109-
raise OverwriteNotAllowedError(error_msg)
11079

111-
# Starting from on-disk highest-resolution data, build and write to disk a
112-
# pyramid of coarser levels
113-
build_pyramid(
114-
zarrurl=zarr_url,
80+
# Compute the new shape and pixel size
81+
dest_on_disk_shape = _compute_new_shape(orginal_image)
82+
83+
dest_pixel_size = orginal_image.pixel_size
84+
dest_pixel_size.z = 1.0
85+
ngio_logger.info(f"New shape: {dest_on_disk_shape=}")
86+
87+
# Create the new empty image
88+
new_ngff_image = original_ngff_image.derive_new_image(
89+
store=zarr_url,
90+
name="MIP",
91+
on_disk_shape=dest_on_disk_shape,
92+
pixel_sizes=dest_pixel_size,
11593
overwrite=init_args.overwrite,
116-
num_levels=ngff_image.num_levels,
117-
coarsening_xy=ngff_image.coarsening_xy,
118-
chunksize=(1, 1, chunksize_y, chunksize_x),
94+
)
95+
new_image = new_ngff_image.get_image()
96+
97+
# Process the image
98+
z_axis_index = orginal_image.find_axis("z")
99+
assert z_axis_index is not None # This should never happen since we checked for 3D images above
100+
101+
source_dask = orginal_image.get_array(
102+
mode="dask", preserve_dimensions=True
119103
)
120104

121-
# Copy over any tables from the original zarr
122-
# Generate the list of tables:
123-
tables = get_tables_list_v1(init_args.origin_url)
124-
roi_tables = get_tables_list_v1(init_args.origin_url, table_type="ROIs")
125-
non_roi_tables = [table for table in tables if table not in roi_tables]
126-
127-
for table in roi_tables:
128-
logger.info(
129-
f"Reading {table} from "
130-
f"{init_args.origin_url=}, convert it to 2D, and "
131-
"write it back to the new zarr file."
132-
)
133-
new_ROI_table = ad.read_zarr(f"{init_args.origin_url}/tables/{table}")
134-
old_ROI_table_attrs = zarr.open_group(
135-
f"{init_args.origin_url}/tables/{table}"
136-
).attrs.asdict()
137-
138-
# Convert 3D ROIs to 2D
139-
pxl_sizes_zyx = ngff_image.get_pixel_sizes_zyx(level=0)
140-
new_ROI_table = convert_ROIs_from_3D_to_2D(
141-
new_ROI_table, pixel_size_z=pxl_sizes_zyx[0]
142-
)
143-
# Write new table
144-
write_table(
145-
new_image_group,
146-
table,
147-
new_ROI_table,
148-
table_attrs=old_ROI_table_attrs,
149-
overwrite=init_args.overwrite,
105+
dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
106+
dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
107+
new_image.set_array(dest_dask)
108+
new_image.consolidate()
109+
# Ends
110+
111+
# Copy over the tables
112+
for roi_table in original_ngff_image.tables.list(table_type="roi_table"):
113+
table = original_ngff_image.tables.get_table(roi_table)
114+
mip_table = new_ngff_image.tables.new(
115+
roi_table, table_type="roi_table", overwrite=True
150116
)
117+
118+
roi_list = []
119+
for roi in table.rois:
120+
roi.z_length = roi.z + 1
121+
roi_list.append(roi)
151122

152-
for table in non_roi_tables:
153-
logger.info(
154-
f"Reading {table} from "
155-
f"{init_args.origin_url=}, and "
156-
"write it back to the new zarr file."
157-
)
158-
new_non_ROI_table = ad.read_zarr(
159-
f"{init_args.origin_url}/tables/{table}"
160-
)
161-
old_non_ROI_table_attrs = zarr.open_group(
162-
f"{init_args.origin_url}/tables/{table}"
163-
).attrs.asdict()
164-
165-
# Write new table
166-
write_table(
167-
new_image_group,
168-
table,
169-
new_non_ROI_table,
170-
table_attrs=old_non_ROI_table_attrs,
171-
overwrite=init_args.overwrite,
172-
)
123+
mip_table.set_rois(roi_list, overwrite=True)
124+
mip_table.consolidate()
125+
ngio_logger.info(f"Table {roi_table} copied.")
173126

174127
# Generate image_list_updates
175128
image_list_update_dict = dict(
@@ -189,5 +142,5 @@ def projection(
189142

190143
run_fractal_task(
191144
task_function=projection,
192-
logger_name=logger.name,
193-
)
145+
logger_name=ngio_logger.name,
146+
)

tests/tasks/test_workflows.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121
import zarr
2222
from devtools import debug
23+
from ngio.utils import NgioFileExistsError
2324

2425
from ._validation import check_file_number
2526
from ._validation import validate_schema
@@ -373,15 +374,13 @@ def test_MIP(
373374
init_args=image["init_args"],
374375
)
375376

376-
# Re-run with overwrite=False
377-
with pytest.raises(OverwriteNotAllowedError):
377+
with pytest.raises(NgioFileExistsError):
378378
for image in parallelization_list:
379379
image["init_args"]["overwrite"] = False
380380
projection(
381381
zarr_url=image["zarr_url"],
382382
init_args=image["init_args"],
383383
)
384-
385384
# OME-NGFF JSON validation
386385
image_zarr = Path(parallelization_list[0]["zarr_url"])
387386
debug(image_zarr)

0 commit comments

Comments
 (0)