2323import json
2424import logging
2525import os
26- from typing import Any , Optional
26+ from typing import Any , Union
2727
2828import anndata as ad
2929import dask .array as da
3030import fractal_tasks_core
3131import numpy as np
3232import vigra
3333import zarr
34- from ilastik_tasks .ilastik_utils import (
35- IlastikChannel1InputModel ,
36- IlastikChannel2InputModel ,
37- get_expected_number_of_channels ,
38- )
3934from fractal_tasks_core .labels import prepare_label_group
4035from fractal_tasks_core .masked_loading import masked_loading_wrapper
4136from fractal_tasks_core .ngff import load_NgffImageMeta
5651from ilastik .applets .dataSelection .opDataSelection import (
5752 PreloadedArrayDatasetInfo ,
5853)
59- from pydantic import validate_call , Field
54+ from pydantic import Field , validate_call
6055from skimage .measure import label , regionprops
6156from skimage .morphology import remove_small_holes
6257
58+ from ilastik_tasks .ilastik_utils import (
59+ IlastikChannel1InputModel ,
60+ IlastikChannel2InputModel ,
61+ get_expected_number_of_channels ,
62+ )
63+
6364logger = logging .getLogger (__name__ )
6465
6566__OME_NGFF_VERSION__ = fractal_tasks_core .__OME_NGFF_VERSION__
@@ -82,7 +83,7 @@ def segment_ROI(
8283 foreground_class : int = 0 ,
8384 threshold : float = 0.5 ,
8485 min_size : int = 15 ,
85- label_dtype : Optional [np .dtype ] = None ,
86+ label_dtype : Union [np .dtype , None ] = None ,
8687 relabeling : bool = True ,
8788) -> np .ndarray :
8889 """Run the Ilastik model on a single ROI.
@@ -130,9 +131,7 @@ def segment_ROI(
130131 ilastik_labels = ilastik_output > threshold
131132
132133 # remove small holes
133- ilastik_labels = remove_small_holes (
134- ilastik_labels , area_threshold = min_size
135- )
134+ ilastik_labels = remove_small_holes (ilastik_labels , area_threshold = min_size )
136135
137136 # label image
138137 ilastik_labels = label (ilastik_labels )
@@ -148,12 +147,16 @@ def segment_ROI(
148147 or (label_props [i ].axis_major_length < 1 )
149148 or (label_props [i ].major_axis_length < 1 )
150149 ]
151- logger .info (f"number of labels before filtering for size = { ilastik_labels .max ()} " )
150+ logger .info (
151+ f"number of labels before filtering for size = { ilastik_labels .max ()} "
152+ )
152153 ilastik_labels [np .isin (ilastik_labels , labels2remove )] = 0
153154 ilastik_labels = label (ilastik_labels )
154- logger .info (f"number of labels after filtering for size = { ilastik_labels .max ()} " )
155+ logger .info (
156+ f"number of labels after filtering for size = { ilastik_labels .max ()} "
157+ )
155158 label_props = regionprops (ilastik_labels )
156-
159+
157160 # Shift labels and update relabeling counters
158161 if relabeling :
159162 num_labels_roi = np .max (ilastik_labels )
@@ -186,8 +189,8 @@ def ilastik_pixel_classification_segmentation(
186189 default_factory = IlastikChannel2InputModel
187190 ),
188191 input_ROI_table : str = "FOV_ROI_table" ,
189- output_ROI_table : Optional [str ] = None ,
190- output_label_name : Optional [str ] = None ,
192+ output_ROI_table : Union [str , None ] = None ,
193+ output_label_name : Union [str , None ] = None ,
191194 use_masks : bool = True ,
192195 # Ilastik-related arguments
193196 ilastik_model : str ,
@@ -252,10 +255,10 @@ def ilastik_pixel_classification_segmentation(
252255
253256 # Setup Ilastik headless shell
254257 shell = setup_ilastik (ilastik_model )
255-
258+
256259 # Check if channel input fits expected number of channels of model
257260 expected_num_channels = get_expected_number_of_channels (shell )
258-
261+
259262 if expected_num_channels == 2 and not channel2 .is_set ():
260263 raise ValueError (
261264 "Ilastik model expects two channels as "
@@ -265,7 +268,7 @@ def ilastik_pixel_classification_segmentation(
265268 raise ValueError (
266269 "Ilastik model expects 1 channel as " "input but two channels were provided"
267270 )
268-
271+
269272 elif expected_num_channels > 2 :
270273 raise NotImplementedError (
271274 f"Expected { expected_num_channels } channels, "
@@ -286,7 +289,7 @@ def ilastik_pixel_classification_segmentation(
286289 ind_channel_c2 = omero_channel_2 .index
287290 else :
288291 return
289-
292+
290293 # Set channel label
291294 if output_label_name is None :
292295 try :
@@ -516,6 +519,21 @@ def ilastik_pixel_classification_segmentation(
516519 f"{ len (overlap_list )} bounding-box pairs overlap"
517520 )
518521
522+ # Check that the shape of the new label image matches the expected shape
523+ expected_shape = (
524+ e_z - s_z ,
525+ e_y - s_y ,
526+ e_x - s_x ,
527+ )
528+ if new_label_img .shape != expected_shape :
529+ try :
530+ new_label_img = da .broadcast_to (new_label_img , expected_shape )
531+ except :
532+ raise ValueError (
533+ f"Shape mismatch: { new_label_img .shape } != { expected_shape } "
534+ "Between the segmented label image and expected shape in the zarr array."
535+ )
536+
519537 # Compute and store 0-th level to disk
520538 da .array (new_label_img ).to_zarr (
521539 url = mask_zarr ,
@@ -564,7 +582,6 @@ def ilastik_pixel_classification_segmentation(
564582 )
565583
566584
567-
568585if __name__ == "__main__" :
569586 from fractal_task_tools .task_wrapper import run_fractal_task
570587
0 commit comments