Skip to content

Commit 32c4da5

Browse files
Merge pull request #7 from rhornb/pred_error
bug fixes; code cleanup
2 parents c69ef4a + bd7cade commit 32c4da5

File tree

4 files changed

+217
-57
lines changed

4 files changed

+217
-57
lines changed

src/ilastik_tasks/__FRACTAL_MANIFEST__.json

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,34 @@
1515
},
1616
"args_schema_parallel": {
1717
"$defs": {
18-
"ChannelInputModel": {
19-
"description": "A channel which is specified by either `wavelength_id` or `label`.",
18+
"IlastikChannel1InputModel": {
19+
"description": "Channel input for ilastik.",
2020
"properties": {
2121
"wavelength_id": {
2222
"title": "Wavelength Id",
23-
"type": "string",
24-
"description": "Unique ID for the channel wavelength, e.g. `A01_C01`. Can only be specified if label is not set."
23+
"type": "string"
2524
},
2625
"label": {
2726
"title": "Label",
28-
"type": "string",
29-
"description": "Name of the channel. Can only be specified if wavelength_id is not set."
27+
"type": "string"
3028
}
3129
},
32-
"title": "ChannelInputModel",
30+
"title": "IlastikChannel1InputModel",
31+
"type": "object"
32+
},
33+
"IlastikChannel2InputModel": {
34+
"description": "Channel input for secondary ilastik channel.",
35+
"properties": {
36+
"wavelength_id": {
37+
"title": "Wavelength Id",
38+
"type": "string"
39+
},
40+
"label": {
41+
"title": "Label",
42+
"type": "string"
43+
}
44+
},
45+
"title": "IlastikChannel2InputModel",
3346
"type": "object"
3447
}
3548
},
@@ -46,16 +59,12 @@
4659
"description": "Pyramid level of the image to be segmented. Choose `0` to process at full resolution."
4760
},
4861
"channel": {
49-
"$ref": "#/$defs/ChannelInputModel",
62+
"$ref": "#/$defs/IlastikChannel1InputModel",
5063
"title": "Channel",
5164
"description": "Primary channel for pixel classification; requires either `wavelength_id` (e.g. `A01_C01`) or `label` (e.g. `DAPI`)."
5265
},
5366
"channel2": {
54-
"allOf": [
55-
{
56-
"$ref": "#/$defs/ChannelInputModel"
57-
}
58-
],
67+
"$ref": "#/$defs/IlastikChannel2InputModel",
5968
"title": "Channel2",
6069
"description": "Second channel for pixel classification (in the same format as `channel`). Use only if second channel has also been used during Ilastik model training."
6170
},
@@ -86,6 +95,12 @@
8695
"type": "string",
8796
"description": "Path to the Ilastik model (e.g. `\"somemodel.ilp\"`)."
8897
},
98+
"relabeling": {
99+
"default": true,
100+
"title": "Relabeling",
101+
"type": "boolean",
102+
"description": "If `True`, apply relabeling so that label values are unique for all objects in the well."
103+
},
89104
"foreground_class": {
90105
"default": 0,
91106
"title": "Foreground Class",

src/ilastik_tasks/ilastik_pixel_classification_segmentation.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131
import numpy as np
3232
import vigra
3333
import zarr
34-
from fractal_tasks_core.channels import ChannelInputModel, get_channel_from_image_zarr
34+
from ilastik_tasks.ilastik_utils import (
35+
IlastikChannel1InputModel,
36+
IlastikChannel2InputModel,
37+
)
3538
from fractal_tasks_core.labels import prepare_label_group
3639
from fractal_tasks_core.masked_loading import masked_loading_wrapper
3740
from fractal_tasks_core.ngff import load_NgffImageMeta
@@ -52,7 +55,7 @@
5255
from ilastik.applets.dataSelection.opDataSelection import (
5356
PreloadedArrayDatasetInfo,
5457
)
55-
from pydantic import validate_call
58+
from pydantic import validate_call, Field
5659
from skimage.measure import label, regionprops
5760
from skimage.morphology import remove_small_holes
5861

@@ -73,22 +76,29 @@ def setup_ilastik(model_path: str):
7376

7477
def segment_ROI(
7578
input_data: np.ndarray,
79+
num_labels_tot: dict[str, int],
7680
shell: Any,
7781
foreground_class: int = 0,
7882
threshold: float = 0.5,
7983
min_size: int = 15,
8084
label_dtype: Optional[np.dtype] = None,
85+
relabeling: bool = True,
8186
) -> np.ndarray:
8287
"""Run the Ilastik model on a single ROI.
8388
8489
Args:
8590
input_data: np.ndarray of shape (t, z, y, x, c).
91+
num_labels_tot: Number of labels already in total image. Used for
92+
relabeling purposes. Using a dict to have a mutable object that
93+
can be edited from within the function without having to be passed
94+
back through the masked_loading_wrapper.
8695
shell: Ilastik headless shell.
8796
foreground_class: Class to be considered as foreground
8897
during prediction thresholding.
8998
threshold: Threshold for the Ilastik model.
9099
min_size: Minimum size for the Ilastik model.
91100
label_dtype: Label images are cast into this `np.dtype`.
101+
relabeling: Whether relabeling based on num_labels_tot is performed.
92102
93103
Returns:
94104
np.ndarray: Segmented image. Shape (z, y, x).
@@ -97,7 +107,7 @@ def segment_ROI(
97107
# Shape from (czyx) to (tzyxc)
98108
input_data = np.moveaxis(input_data, 0, -1)
99109
input_data = np.expand_dims(input_data, axis=0)
100-
print(f"{input_data.shape=}")
110+
logger.info(f"{input_data.shape=}")
101111
data = [
102112
{
103113
"Raw Data": PreloadedArrayDatasetInfo(
@@ -109,26 +119,22 @@ def segment_ROI(
109119
ilastik_output = shell.workflow.batchProcessingApplet.run_export(
110120
data, export_to_array=True
111121
)[0]
112-
logger.info(f"{ilastik_output.shape=}")
122+
logger.info(f"{ilastik_output.shape=} after ilastik prediction")
113123

114124
# Get foreground class and reshape to 3D
115-
ilastik_output = ilastik_output[..., foreground_class]
116-
ilastik_output = np.reshape(
117-
ilastik_output, (input_data.shape[1], input_data.shape[2], input_data.shape[3])
118-
)
119-
logger.info(f"{ilastik_output.shape=}")
125+
ilastik_output = np.squeeze(ilastik_output[..., foreground_class])
126+
logger.info(f"{ilastik_output.shape=} after foreground class selection")
120127

121128
# take mask of regions above threshold
122-
ilastik_output[ilastik_output < threshold] = 0
123-
ilastik_output[ilastik_output >= threshold] = 1
129+
ilastik_labels = ilastik_output > threshold
124130

125131
# remove small holes
126-
ilastik_output = remove_small_holes(
127-
ilastik_output.astype(bool), area_threshold=min_size
132+
ilastik_labels = remove_small_holes(
133+
ilastik_labels, area_threshold=min_size
128134
)
129135

130136
# label image
131-
ilastik_labels = label(ilastik_output)
137+
ilastik_labels = label(ilastik_labels)
132138

133139
# remove objects below min_size - also removes anything with major or minor axis
134140
# length of 0 for compatibility with current measurements task (01.24)
@@ -141,11 +147,28 @@ def segment_ROI(
141147
or (label_props[i].axis_major_length < 1)
142148
or (label_props[i].major_axis_length < 1)
143149
]
144-
print(f"number of labels before filtering for size = {ilastik_labels.max()}")
150+
logger.info(f"number of labels before filtering for size = {ilastik_labels.max()}")
145151
ilastik_labels[np.isin(ilastik_labels, labels2remove)] = 0
146152
ilastik_labels = label(ilastik_labels)
147-
print(f"number of labels after filtering for size = {ilastik_labels.max()}")
153+
logger.info(f"number of labels after filtering for size = {ilastik_labels.max()}")
148154
label_props = regionprops(ilastik_labels)
155+
156+
# Shift labels and update relabeling counters
157+
if relabeling:
158+
num_labels_roi = np.max(ilastik_labels)
159+
ilastik_labels[ilastik_labels > 0] += num_labels_tot["num_labels_tot"]
160+
num_labels_tot["num_labels_tot"] += num_labels_roi
161+
162+
# Write some logs
163+
logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}")
164+
165+
# Check that total number of labels is under control
166+
if num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max:
167+
raise ValueError(
168+
"ERROR in re-labeling:"
169+
f"Reached {num_labels_tot} labels, "
170+
f"but dtype={label_dtype}"
171+
)
149172

150173
return ilastik_labels.astype(label_dtype)
151174

@@ -157,14 +180,17 @@ def ilastik_pixel_classification_segmentation(
157180
zarr_url: str,
158181
# Task-specific arguments
159182
level: int,
160-
channel: ChannelInputModel,
161-
channel2: Optional[ChannelInputModel] = None,
183+
channel: IlastikChannel1InputModel,
184+
channel2: IlastikChannel2InputModel = Field(
185+
default_factory=IlastikChannel2InputModel
186+
),
162187
input_ROI_table: str = "FOV_ROI_table",
163188
output_ROI_table: Optional[str] = None,
164189
output_label_name: Optional[str] = None,
165190
use_masks: bool = True,
166191
# Ilastik-related arguments
167192
ilastik_model: str,
193+
relabeling: bool = True,
168194
foreground_class: int = 0,
169195
threshold: float = 0.5,
170196
min_size: int = 15,
@@ -196,6 +222,8 @@ def ilastik_pixel_classification_segmentation(
196222
loading is relevant when only a subset of the bounding box should
197223
actually be processed (e.g. running within `emb_ROI_table`).
198224
ilastik_model: Path to the Ilastik model (e.g. `"somemodel.ilp"`).
225+
relabeling: If `True`, apply relabeling so that label values are
226+
unique for all objects in the well.
199227
foreground_class: Class to be considered as foreground during
200228
prediction thresholding.
201229
threshold: Probabiltiy threshold for the Ilastik model.
@@ -237,32 +265,24 @@ def ilastik_pixel_classification_segmentation(
237265
)
238266

239267
# Find channel index
240-
tmp_channel = get_channel_from_image_zarr(
241-
image_zarr_path=zarr_url,
242-
wavelength_id=channel.wavelength_id,
243-
label=channel.label,
244-
)
245-
if tmp_channel:
246-
ind_channel = tmp_channel.index
268+
omero_channel = channel.get_omero_channel(zarr_url)
269+
if omero_channel:
270+
ind_channel = omero_channel.index
247271
else:
248272
return
249273

250274
# Find channel index for second channel, if one is provided
251-
if channel2:
252-
tmp_channel_2 = get_channel_from_image_zarr(
253-
image_zarr_path=zarr_url,
254-
wavelength_id=channel2.wavelength_id,
255-
label=channel2.label,
256-
)
257-
if tmp_channel_2:
258-
ind_channel_c2 = tmp_channel.index
275+
if channel2.is_set():
276+
omero_channel_2 = channel2.get_omero_channel(zarr_url)
277+
if omero_channel_2:
278+
ind_channel_c2 = omero_channel_2.index
259279
else:
260-
return ValueError(f"Channel {channel2} could not be loaded.")
261-
262-
# Set output channel label if none is provided
280+
return
281+
282+
# Set channel label
263283
if output_label_name is None:
264284
try:
265-
channel_label = tmp_channel.label
285+
channel_label = omero_channel.label
266286
output_label_name = f"label_{channel_label}"
267287
except (KeyError, IndexError):
268288
output_label_name = f"label_{ind_channel}"
@@ -392,9 +412,13 @@ def ilastik_pixel_classification_segmentation(
392412

393413
# Initialize other things
394414
logger.info(f"Start ilastik pixel classification task for {zarr_url}")
415+
logger.info(f"relabeling: {relabeling}")
395416
logger.info(f"{data_zyx.shape}")
396417
logger.info(f"{data_zyx.chunks}")
397418

419+
# Counters for relabeling
420+
num_labels_tot = {"num_labels_tot": 0}
421+
398422
# Iterate over ROIs
399423
num_ROIs = len(list_indices)
400424

@@ -440,10 +464,12 @@ def ilastik_pixel_classification_segmentation(
440464
# Prepare keyword arguments for segment_ROI function
441465
kwargs_segment_ROI = {
442466
"shell": shell,
467+
"num_labels_tot": num_labels_tot,
443468
"foreground_class": foreground_class,
444469
"threshold": threshold,
445470
"min_size": min_size,
446471
"label_dtype": label_dtype,
472+
"relabeling": relabeling,
447473
}
448474

449475
# Prepare keyword arguments for preprocessing function

0 commit comments

Comments
 (0)