1212"""
1313Task for 3D->2D maximum-intensity projection.
1414"""
15+ from __future__ import annotations
16+
1517import logging
1618from typing import Any
1719
18- import anndata as ad
1920import dask .array as da
20- import zarr
21+ from ngio import NgffImage
22+ from ngio .core import Image
2123from pydantic import validate_call
22- from zarr .errors import ContainsArrayError
23-
24- from fractal_tasks_core .ngff import load_NgffImageMeta
25- from fractal_tasks_core .pyramids import build_pyramid
26- from fractal_tasks_core .roi import (
27- convert_ROIs_from_3D_to_2D ,
28- )
29- from fractal_tasks_core .tables import write_table
30- from fractal_tasks_core .tables .v1 import get_tables_list_v1
24+
3125from fractal_tasks_core .tasks .io_models import InitArgsMIP
3226from fractal_tasks_core .tasks .projection_utils import DaskProjectionMethod
33- from fractal_tasks_core .zarr_utils import OverwriteNotAllowedError
34-
3527
3628logger = logging .getLogger (__name__ )
3729
3830
+def _compute_new_shape(source_image: Image) -> tuple[int, ...]:
+    """Compute the new shape of the image after the projection.
+
+    The new shape is the same as the original one,
+    except for the z-axis, which is set to 1.
+    """
+    on_disk_shape = source_image.on_disk_shape
+    logger.info(f"Source {on_disk_shape=}")
+
+    on_disk_z_index = source_image.dataset.on_disk_axes_names.index("z")
+
+    dest_on_disk_shape = list(on_disk_shape)
+    dest_on_disk_shape[on_disk_z_index] = 1
+    logger.info(f"Destination {dest_on_disk_shape=}")
+    return tuple(dest_on_disk_shape)
+
+
 @validate_call
 def projection(
     *,
@@ -60,123 +69,68 @@ def projection(
     logger.info(f"{method=}")
 
     # Read image metadata
-    ngff_image = load_NgffImageMeta(init_args.origin_url)
-    # Currently not using the validation models due to wavelength_id issue
-    # See #681 for discussion
-    # new_attrs = ngff_image.model_dump(exclude_none=True)
-    # Current way to get the necessary metadata for MIP
-    group = zarr.open_group(init_args.origin_url, mode="r")
-    new_attrs = group.attrs.asdict()
-
-    # Create the zarr image with correct
-    new_image_group = zarr.group(zarr_url)
-    new_image_group.attrs.put(new_attrs)
-
-    # Load 0-th level
-    data_czyx = da.from_zarr(init_args.origin_url + "/0")
-    num_channels = data_czyx.shape[0]
-    chunksize_y = data_czyx.chunksize[-2]
-    chunksize_x = data_czyx.chunksize[-1]
-    logger.info(f"{num_channels=}")
-    logger.info(f"{chunksize_y=}")
-    logger.info(f"{chunksize_x=}")
-
-    # Loop over channels
-    accumulate_chl = []
-    for ind_ch in range(num_channels):
-        # Perform MIP for each channel of level 0
-        project_yx = da.stack(
-            [method.apply(data_czyx[ind_ch], axis=0)], axis=0
-        )
-        accumulate_chl.append(project_yx)
-    accumulated_array = da.stack(accumulate_chl, axis=0)
-
-    # Write to disk (triggering execution)
-    try:
-        accumulated_array.to_zarr(
-            f"{zarr_url}/0",
-            overwrite=init_args.overwrite,
-            dimension_separator="/",
-            write_empty_chunks=False,
-        )
-    except ContainsArrayError as e:
-        error_msg = (
-            f"Cannot write array to zarr group at '{zarr_url}/0', "
-            f"with {init_args.overwrite=} (original error: {str(e)}).\n"
-            "Hint: try setting overwrite=True."
+    original_ngff_image = NgffImage(init_args.origin_url)
+    original_image = original_ngff_image.get_image()
+
+    if original_image.is_2d or original_image.is_2d_time_series:
+        raise ValueError(
+            "The input image is 2D; "
+            "projection is only supported for 3D images."
         )
-        logger.error(error_msg)
-        raise OverwriteNotAllowedError(error_msg)
 
-    # Starting from on-disk highest-resolution data, build and write to disk a
-    # pyramid of coarser levels
-    build_pyramid(
-        zarrurl=zarr_url,
+    # Compute the new shape and pixel size
+    dest_on_disk_shape = _compute_new_shape(original_image)
+
+    dest_pixel_size = original_image.pixel_size
+    dest_pixel_size.z = 1.0
+    logger.info(f"New shape: {dest_on_disk_shape=}")
+
+    # Create the new empty image
+    new_ngff_image = original_ngff_image.derive_new_image(
+        store=zarr_url,
+        name="MIP",
+        on_disk_shape=dest_on_disk_shape,
+        pixel_sizes=dest_pixel_size,
         overwrite=init_args.overwrite,
-        num_levels=ngff_image.num_levels,
-        coarsening_xy=ngff_image.coarsening_xy,
-        chunksize=(1, 1, chunksize_y, chunksize_x),
+        copy_labels=False,
+        copy_tables=True,
     )
+    logger.info(f"New projection image created: {new_ngff_image=}")
+    new_image = new_ngff_image.get_image()
 
-    # Copy over any tables from the original zarr
-    # Generate the list of tables:
-    tables = get_tables_list_v1(init_args.origin_url)
-    roi_tables = get_tables_list_v1(init_args.origin_url, table_type="ROIs")
-    non_roi_tables = [table for table in tables if table not in roi_tables]
-
-    for table in roi_tables:
-        logger.info(
-            f"Reading {table} from "
-            f"{init_args.origin_url=}, convert it to 2D, and "
-            "write it back to the new zarr file."
-        )
-        new_ROI_table = ad.read_zarr(f"{init_args.origin_url}/tables/{table}")
-        old_ROI_table_attrs = zarr.open_group(
-            f"{init_args.origin_url}/tables/{table}"
-        ).attrs.asdict()
-
-        # Convert 3D ROIs to 2D
-        pxl_sizes_zyx = ngff_image.get_pixel_sizes_zyx(level=0)
-        new_ROI_table = convert_ROIs_from_3D_to_2D(
-            new_ROI_table, pixel_size_z=pxl_sizes_zyx[0]
-        )
-        # Write new table
-        write_table(
-            new_image_group,
-            table,
-            new_ROI_table,
-            table_attrs=old_ROI_table_attrs,
-            overwrite=init_args.overwrite,
-        )
+    # Process the image
+    z_axis_index = original_image.find_axis("z")
+    source_dask = original_image.get_array(
+        mode="dask", preserve_dimensions=True
+    )
 
-    for table in non_roi_tables:
-        logger.info(
-            f"Reading {table} from "
-            f"{init_args.origin_url=}, and "
-            "write it back to the new zarr file."
-        )
-        new_non_ROI_table = ad.read_zarr(
-            f"{init_args.origin_url}/tables/{table}"
-        )
-        old_non_ROI_table_attrs = zarr.open_group(
-            f"{init_args.origin_url}/tables/{table}"
-        ).attrs.asdict()
-
-        # Write new table
-        write_table(
-            new_image_group,
-            table,
-            new_non_ROI_table,
-            table_attrs=old_non_ROI_table_attrs,
-            overwrite=init_args.overwrite,
-        )
+    dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index)
+    dest_dask = da.expand_dims(dest_dask, axis=z_axis_index)
+    new_image.set_array(dest_dask)
+    new_image.consolidate()
+    # End of projection computation
+
+    # Copy over the tables
+    for roi_table_name in new_ngff_image.tables.list(table_type="roi_table"):
+        table = new_ngff_image.tables.get_table(roi_table_name)
+
+        roi_list = []
+        for roi in table.rois:
+            roi.z = 0.0
+            roi.z_length = 1.0
+            roi_list.append(roi)
+
+        table.set_rois(roi_list, overwrite=True)
+        table.consolidate()
+        logger.info(f"Projection of table {roi_table_name} done")
 
     # Generate image_list_updates
     image_list_update_dict = dict(
         image_list_updates=[
             dict(
                 zarr_url=zarr_url,
                 origin=init_args.origin_url,
+                attributes=dict(plate=init_args.new_plate_name),
                 types=dict(is_3D=False),
             )
         ]
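For orientation, the heart of the new implementation is the z-projection itself: read the source array lazily with dask, reduce along the z axis with the selected method, re-insert a singleton z dimension, and write the result. A minimal, self-contained sketch of that step with plain dask follows; the Zarr paths, the CZYX axis order, and the use of `max` as the projection method are illustrative assumptions, not the task's actual I/O path or API.

```python
import dask.array as da

# Hypothetical level-0 array of a 3D OME-Zarr image, assumed to be CZYX.
source = da.from_zarr("plate.zarr/B/03/0/0")
z_axis = 1  # assumption: axes are (c, z, y, x)

# Maximum-intensity projection along z; stays lazy until written.
mip = source.max(axis=z_axis)

# Re-insert a singleton z axis so the result keeps the original dimensionality.
mip = da.expand_dims(mip, axis=z_axis)

# Writing to Zarr triggers the computation; overwrite mirrors init_args.overwrite.
mip.to_zarr("plate_mip.zarr/B/03/0/0", overwrite=True)
```

Keeping the singleton z axis (via `expand_dims`) is what lets the projected image reuse the original OME-NGFF axis metadata, with only the z shape and z pixel size changed to 1.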