Skip to content

Commit 324e0d2

Browse files
fix bug when running 2D models on 1yx ome zarrs
1 parent 49e19ed commit 324e0d2

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

src/ilastik_tasks/ilastik_pixel_classification_segmentation.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,14 @@
2323
import json
2424
import logging
2525
import os
26-
from typing import Any, Optional
26+
from typing import Any
2727

2828
import anndata as ad
2929
import dask.array as da
3030
import fractal_tasks_core
3131
import numpy as np
3232
import vigra
3333
import zarr
34-
from ilastik_tasks.ilastik_utils import (
35-
IlastikChannel1InputModel,
36-
IlastikChannel2InputModel,
37-
get_expected_number_of_channels,
38-
)
3934
from fractal_tasks_core.labels import prepare_label_group
4035
from fractal_tasks_core.masked_loading import masked_loading_wrapper
4136
from fractal_tasks_core.ngff import load_NgffImageMeta
@@ -56,10 +51,16 @@
5651
from ilastik.applets.dataSelection.opDataSelection import (
5752
PreloadedArrayDatasetInfo,
5853
)
59-
from pydantic import validate_call, Field
54+
from pydantic import Field, validate_call
6055
from skimage.measure import label, regionprops
6156
from skimage.morphology import remove_small_holes
6257

58+
from ilastik_tasks.ilastik_utils import (
59+
IlastikChannel1InputModel,
60+
IlastikChannel2InputModel,
61+
get_expected_number_of_channels,
62+
)
63+
6364
logger = logging.getLogger(__name__)
6465

6566
__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__
@@ -82,7 +83,7 @@ def segment_ROI(
8283
foreground_class: int = 0,
8384
threshold: float = 0.5,
8485
min_size: int = 15,
85-
label_dtype: Optional[np.dtype] = None,
86+
label_dtype: np.dtype | None = None,
8687
relabeling: bool = True,
8788
) -> np.ndarray:
8889
"""Run the Ilastik model on a single ROI.
@@ -130,9 +131,7 @@ def segment_ROI(
130131
ilastik_labels = ilastik_output > threshold
131132

132133
# remove small holes
133-
ilastik_labels = remove_small_holes(
134-
ilastik_labels, area_threshold=min_size
135-
)
134+
ilastik_labels = remove_small_holes(ilastik_labels, area_threshold=min_size)
136135

137136
# label image
138137
ilastik_labels = label(ilastik_labels)
@@ -148,12 +147,16 @@ def segment_ROI(
148147
or (label_props[i].axis_major_length < 1)
149148
or (label_props[i].major_axis_length < 1)
150149
]
151-
logger.info(f"number of labels before filtering for size = {ilastik_labels.max()}")
150+
logger.info(
151+
f"number of labels before filtering for size = {ilastik_labels.max()}"
152+
)
152153
ilastik_labels[np.isin(ilastik_labels, labels2remove)] = 0
153154
ilastik_labels = label(ilastik_labels)
154-
logger.info(f"number of labels after filtering for size = {ilastik_labels.max()}")
155+
logger.info(
156+
f"number of labels after filtering for size = {ilastik_labels.max()}"
157+
)
155158
label_props = regionprops(ilastik_labels)
156-
159+
157160
# Shift labels and update relabeling counters
158161
if relabeling:
159162
num_labels_roi = np.max(ilastik_labels)
@@ -186,8 +189,8 @@ def ilastik_pixel_classification_segmentation(
186189
default_factory=IlastikChannel2InputModel
187190
),
188191
input_ROI_table: str = "FOV_ROI_table",
189-
output_ROI_table: Optional[str] = None,
190-
output_label_name: Optional[str] = None,
192+
output_ROI_table: str | None = None,
193+
output_label_name: str | None = None,
191194
use_masks: bool = True,
192195
# Ilastik-related arguments
193196
ilastik_model: str,
@@ -252,10 +255,10 @@ def ilastik_pixel_classification_segmentation(
252255

253256
# Setup Ilastik headless shell
254257
shell = setup_ilastik(ilastik_model)
255-
258+
256259
# Check if channel input fits expected number of channels of model
257260
expected_num_channels = get_expected_number_of_channels(shell)
258-
261+
259262
if expected_num_channels == 2 and not channel2.is_set():
260263
raise ValueError(
261264
"Ilastik model expects two channels as "
@@ -265,7 +268,7 @@ def ilastik_pixel_classification_segmentation(
265268
raise ValueError(
266269
"Ilastik model expects 1 channel as " "input but two channels were provided"
267270
)
268-
271+
269272
elif expected_num_channels > 2:
270273
raise NotImplementedError(
271274
f"Expected {expected_num_channels} channels, "
@@ -286,7 +289,7 @@ def ilastik_pixel_classification_segmentation(
286289
ind_channel_c2 = omero_channel_2.index
287290
else:
288291
return
289-
292+
290293
# Set channel label
291294
if output_label_name is None:
292295
try:
@@ -516,6 +519,21 @@ def ilastik_pixel_classification_segmentation(
516519
f"{len(overlap_list)} bounding-box pairs overlap"
517520
)
518521

522+
# Check that the shape of the new label image matches the expected shape
523+
expected_shape = (
524+
e_z - s_z,
525+
e_y - s_y,
526+
e_x - s_x,
527+
)
528+
if new_label_img.shape != expected_shape:
529+
try:
530+
new_label_img = da.broadcast_to(new_label_img, expected_shape)
531+
except:
532+
raise ValueError(
533+
f"Shape mismatch: {new_label_img.shape} != {expected_shape} "
534+
"Between the segmented label image and expected shape in the zarr array."
535+
)
536+
519537
# Compute and store 0-th level to disk
520538
da.array(new_label_img).to_zarr(
521539
url=mask_zarr,
@@ -564,7 +582,6 @@ def ilastik_pixel_classification_segmentation(
564582
)
565583

566584

567-
568585
if __name__ == "__main__":
569586
from fractal_task_tools.task_wrapper import run_fractal_task
570587

tests/test_ilastik_pixel_classification_segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_ilastik_pixel_classification_segmentation_task_2D_single_channel(
126126
ilastik_pixel_classification_segmentation(
127127
zarr_url=ome_zarr_2d_url,
128128
level=1,
129-
channel=IlastikChannel1InputModel(label="DAPI_2"),
129+
channel=IlastikChannel1InputModel(label="DAPI"),
130130
channel2=IlastikChannel2InputModel(label=None),
131131
ilastik_model=str(ilastik_model),
132132
output_label_name="test_label",
@@ -139,8 +139,8 @@ def test_ilastik_pixel_classification_segmentation_task_2D_single_channel(
139139
ilastik_pixel_classification_segmentation(
140140
zarr_url=ome_zarr_2d_url,
141141
level=1,
142-
channel=IlastikChannel1InputModel(label="DAPI_2"),
143-
channel2=IlastikChannel2InputModel(label="ECadherin_2"),
142+
channel=IlastikChannel1InputModel(label="DAPI"),
143+
channel2=IlastikChannel2InputModel(label="ECadherin"),
144144
ilastik_model=str(ilastik_model),
145145
output_label_name="test_label",
146146
relabeling=True,

0 commit comments

Comments
 (0)