Skip to content

Commit 01d4170

Browse files
authored
Merge pull request #7 from rhornb/single_channel
Single channel input model fix
2 parents 41ae86e + bc14bba commit 01d4170

File tree

6 files changed

+81
-35
lines changed

6 files changed

+81
-35
lines changed

.github/workflows/build_and_test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ jobs:
1717
strategy:
1818
matrix:
1919
os: [ubuntu-22.04, macos-latest]
20-
python-version: ["3.10"]
20+
python-version: ["3.11"]
2121
name: "Core, Python ${{ matrix.python-version }}, ${{ matrix.os }}"
2222
runs-on: ${{ matrix.os }}
23-
timeout-minutes: 10
23+
timeout-minutes: 30
2424

2525
steps:
2626
- uses: actions/checkout@v4
@@ -29,7 +29,7 @@ jobs:
2929

3030
- uses: prefix-dev/[email protected]
3131
with:
32-
pixi-version: v0.44.0
32+
pixi-version: v0.50.2
3333

3434
- name: Check if manifest has changed
3535
run: pixi run fractal-manifest check --package ilastik-tasks

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ docstring_parser = "==0.15"
118118
ilastik-core = ">=1.4.2a1,<2"
119119
vigra = ">=1.12.1,<2"
120120
fractal-tasks-core = "==1.5.3"
121-
fractal-task-tools = "==0.0.12"
121+
fractal-task-tools = "==0.0.12"

src/ilastik_tasks/ilastik_pixel_classification_segmentation.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ilastik_tasks.ilastik_utils import (
3535
IlastikChannel1InputModel,
3636
IlastikChannel2InputModel,
37+
get_expected_number_of_channels,
3738
)
3839
from fractal_tasks_core.labels import prepare_label_group
3940
from fractal_tasks_core.masked_loading import masked_loading_wrapper
@@ -251,18 +252,25 @@ def ilastik_pixel_classification_segmentation(
251252

252253
# Setup Ilastik headless shell
253254
shell = setup_ilastik(ilastik_model)
254-
255-
# Check model channel requirements
256-
expected_channels = check_ilastik_model_channels(shell)
257-
if expected_channels == 2 and channel2 is None:
255+
256+
# Check if channel input fits expected number of channels of model
257+
expected_num_channels = get_expected_number_of_channels(shell)
258+
259+
if expected_num_channels == 2 and not channel2.is_set():
258260
raise ValueError(
259261
"Ilastik model expects two channels as "
260262
"input but only one channel was provided"
261263
)
262-
elif expected_channels == 1 and channel2 is not None:
264+
elif expected_num_channels == 1 and channel2.is_set():
263265
raise ValueError(
264266
"Ilastik model expects 1 channel as " "input but two channels were provided"
265267
)
268+
269+
elif expected_num_channels > 2:
270+
raise NotImplementedError(
271+
f"Expected {expected_num_channels} channels, "
272+
"but a maximum of channels are currently supported."
273+
)
266274

267275
# Find channel index
268276
omero_channel = channel.get_omero_channel(zarr_url)
@@ -290,7 +298,7 @@ def ilastik_pixel_classification_segmentation(
290298
# Load ZYX data
291299
data_zyx = da.from_zarr(f"{zarr_url}/{level}")[ind_channel]
292300
logger.info(f"{data_zyx.shape=}")
293-
if channel2:
301+
if channel2.is_set():
294302
data_zyx_c2 = da.from_zarr(f"{zarr_url}/{level}")[ind_channel_c2]
295303
logger.info(f"Second channel: {data_zyx_c2.shape=}")
296304

@@ -437,7 +445,7 @@ def ilastik_pixel_classification_segmentation(
437445
logger.info(f"Now processing ROI {i_ROI+1}/{num_ROIs}")
438446

439447
# Prepare single-channel or dual-channel input for Ilastik
440-
if channel2:
448+
if channel2.is_set():
441449
# Dual channel mode
442450
img_1 = load_region(
443451
data_zyx,
@@ -556,24 +564,6 @@ def ilastik_pixel_classification_segmentation(
556564
)
557565

558566

559-
def check_ilastik_model_channels(shell) -> int:
560-
"""Check number of input channels expected by Ilastik model.
561-
562-
Args:
563-
shell: Initialized Ilastik shell with loaded model
564-
565-
Returns:
566-
int: Number of expected input channels
567-
"""
568-
# Get dataSelection applet from workflow
569-
data_selection = shell.workflow.dataSelectionApplet
570-
571-
# Get slot info containing expected channels
572-
slot_info = data_selection.topLevelOperator.DatasetRoles.value
573-
574-
# Return number of expected channels
575-
return len(slot_info)
576-
577567

578568
if __name__ == "__main__":
579569
from fractal_task_tools.task_wrapper import run_fractal_task

src/ilastik_tasks/ilastik_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,20 @@ def get_omero_channel(self, zarr_url) -> OmeroChannel:
111111
f"Original error: {str(e)}"
112112
)
113113
return None
114+
115+
116+
def get_expected_number_of_channels(shell) -> int:
117+
"""
118+
Get the expected number of channels from the trained ilastik model
119+
"""
120+
opPixelClassification = shell.workflow.pcApplet.topLevelOperator
121+
len_input_images = len(opPixelClassification.InputImages)
122+
channel_number = []
123+
for i in range(len_input_images):
124+
channel = opPixelClassification.InputImages[i].meta.getTaggedShape()["c"]
125+
channel_number.append(channel)
126+
127+
if len(set(channel_number)) != 1:
128+
raise ValueError("Inconsistent number of channels across input images.")
129+
130+
return channel_number[0]
366 KB
Binary file not shown.

tests/test_ilastik_pixel_classification_segmentation.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33

44
import pytest
55
from devtools import debug
6+
7+
from ilastik_tasks.ilastik_pixel_classification_segmentation import (
8+
ilastik_pixel_classification_segmentation,
9+
)
610
from ilastik_tasks.ilastik_utils import (
711
IlastikChannel1InputModel,
812
IlastikChannel2InputModel,
913
)
10-
from ilastik_tasks.ilastik_pixel_classification_segmentation import (
11-
ilastik_pixel_classification_segmentation,
12-
)
1314

1415
# TODO: add 2D testdata
1516

17+
1618
@pytest.fixture(scope="function")
1719
def test_data_dir_3d(tmp_path: Path, zenodo_zarr_3d: list) -> str:
1820
"""
@@ -26,7 +28,9 @@ def test_data_dir_3d(tmp_path: Path, zenodo_zarr_3d: list) -> str:
2628
return dest_dir
2729

2830

29-
def test_ilastik_pixel_classification_segmentation_task_3D(test_data_dir_3d):
31+
def test_ilastik_pixel_classification_segmentation_task_3D_dual_channel(
32+
test_data_dir_3d,
33+
):
3034
"""
3135
Test the 3D ilastik_pixel_classification_segmentation task with dual channel input.
3236
"""
@@ -52,9 +56,44 @@ def test_ilastik_pixel_classification_segmentation_task_3D(test_data_dir_3d):
5256
zarr_url=zarr_url,
5357
level=4,
5458
channel=IlastikChannel1InputModel(label="DAPI_2"),
55-
channel2=None,
59+
channel2=IlastikChannel2InputModel(label=None),
5660
ilastik_model=str(ilastik_model),
57-
output_label_name="test_label_single_channel",
61+
output_label_name="test_label",
62+
relabeling=True,
5863
)
5964

6065

66+
def test_ilastik_pixel_classification_segmentation_task_3D_single_channel(
67+
test_data_dir_3d,
68+
):
69+
"""
70+
Test the 3D ilastik_pixel_classification_segmentation task
71+
with single channel input.
72+
"""
73+
ilastik_model = (
74+
Path(__file__).parent / "data/pixel_classifier_3D_single_channel.ilp"
75+
).as_posix()
76+
zarr_url = f"{test_data_dir_3d}/B/03/0"
77+
78+
ilastik_pixel_classification_segmentation(
79+
zarr_url=zarr_url,
80+
level=4,
81+
channel=IlastikChannel1InputModel(label="DAPI_2"),
82+
channel2=IlastikChannel2InputModel(label=None),
83+
ilastik_model=str(ilastik_model),
84+
output_label_name="test_label",
85+
relabeling=True,
86+
)
87+
88+
# Test failing of task if model was trained with one channel
89+
# but two are provided
90+
with pytest.raises(ValueError):
91+
ilastik_pixel_classification_segmentation(
92+
zarr_url=zarr_url,
93+
level=4,
94+
channel=IlastikChannel1InputModel(label="DAPI_2"),
95+
channel2=IlastikChannel2InputModel(label="ECadherin_2"),
96+
ilastik_model=str(ilastik_model),
97+
output_label_name="test_label",
98+
relabeling=True,
99+
)

0 commit comments

Comments
 (0)