Skip to content

Commit 82b271a

Browse files
committed
test data added, channel input ilastik check
1 parent c045e88 commit 82b271a

File tree

6 files changed

+84
-50
lines changed

6 files changed

+84
-50
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,8 @@ ENV/
109109

110110
# IDE settings
111111
.vscode/
112-
.idea/
112+
.idea/
113+
114+
*.zarr
115+
# Test data loaded via pooch
116+
/tests/data/10_5281_zenodo_14883998

create_env_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ $COMMMAND run $COMMMAND create $LOCATION \
1919
--override-channels \
2020
-c pytorch \
2121
-c ilastik-forge \
22-
-c conda-forge $PYTHON ilastik \
22+
-c conda-forge $PYTHON ilastik vigra \
2323
--no-channel-priority --yes
2424

2525
echo "Installing ilastik-tasks version $VERSION"

src/ilastik_tasks/ilastik_pixel_classification_segmentation.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,17 @@ def ilastik_pixel_classification_segmentation(
221221
)
222222

223223
# Setup Ilastik headless shell
224-
# TODO: check if info about training mode can be retreived e.g. if model needs
225-
# single or dual channel input and throw error if wrong input is provided
226224
shell = setup_ilastik(ilastik_model)
227225

226+
# Check model channel requirements
227+
expected_channels = check_ilastik_model_channels(shell)
228+
if expected_channels == 2 and channel2 is None:
229+
raise ValueError(f"Ilastik model expects two channels as "
230+
"input but only one channel was provided")
231+
elif expected_channels == 1 and channel2 is not None:
232+
raise ValueError(f"Ilastik model expects 1 channel as "
233+
"input but two channels were provided")
234+
228235
# Find channel index
229236
tmp_channel = get_channel_from_image_zarr(
230237
image_zarr_path=zarr_url,
@@ -526,6 +533,25 @@ def ilastik_pixel_classification_segmentation(
526533
)
527534

528535

536+
def check_ilastik_model_channels(shell) -> int:
537+
"""Check number of input channels expected by Ilastik model.
538+
539+
Args:
540+
shell: Initialized Ilastik shell with loaded model
541+
542+
Returns:
543+
int: Number of expected input channels
544+
"""
545+
# Get dataSelection applet from workflow
546+
data_selection = shell.workflow.dataSelectionApplet
547+
548+
# Get slot info containing expected channels
549+
slot_info = data_selection.topLevelOperator.DatasetRoles.value
550+
551+
# Return number of expected channels
552+
return len(slot_info)
553+
554+
529555
if __name__ == "__main__":
530556
from fractal_tasks_core.tasks._utils import run_fractal_task
531557

tests/conftest.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
# copied from
2-
# https://github.com/fractal-analytics-platform/fractal-helper-tasks/blob/main/tests/conftest.py
3-
41
import os
52
import shutil
63
from pathlib import Path
74

85
import pooch
96
import pytest
10-
7+
from devtools import debug
118

129
@pytest.fixture(scope="session")
1310
def testdata_path() -> Path:
@@ -16,7 +13,7 @@ def testdata_path() -> Path:
1613

1714

1815
@pytest.fixture(scope="session")
19-
def zenodo_zarr(testdata_path: Path) -> list[str]:
16+
def zenodo_zarr_3d(testdata_path: Path) -> str:
2017
"""
2118
This takes care of multiple steps:
2219
@@ -25,16 +22,14 @@ def zenodo_zarr(testdata_path: Path) -> list[str]:
2522
3. Modify the Zarrs in tests/data, to add whatever is not in Zenodo
2623
"""
2724

28-
# 1 Download Zarrs from Zenodo
29-
DOI = "10.5281/zenodo.10257149"
25+
# 1) Download Zarrs from Zenodo
26+
DOI = "10.5281/zenodo.14883998"
3027
DOI_slug = DOI.replace("/", "_").replace(".", "_")
31-
platenames = ["plate.zarr", "plate_mip.zarr"]
3228
rootfolder = testdata_path / DOI_slug
33-
folders = [rootfolder / plate for plate in platenames]
29+
folder = rootfolder / "AssayPlate_Greiner_CELLSTAR655090.zarr"
3430

3531
registry = {
36-
"20200812-CardiomyocyteDifferentiation14-Cycle1.zarr.zip": None,
37-
"20200812-CardiomyocyteDifferentiation14-Cycle1_mip.zarr.zip": None,
32+
"AssayPlate_Greiner_CELLSTAR655090.zarr.zip": None,
3833
}
3934
base_url = f"doi:{DOI}"
4035
POOCH = pooch.create(
@@ -45,31 +40,16 @@ def zenodo_zarr(testdata_path: Path) -> list[str]:
4540
allow_updates=False,
4641
)
4742

48-
for ind, file_name in enumerate(
49-
[
50-
"20200812-CardiomyocyteDifferentiation14-Cycle1.zarr",
51-
"20200812-CardiomyocyteDifferentiation14-Cycle1_mip.zarr",
52-
]
53-
):
54-
# 1) Download/unzip a single Zarr from Zenodo
55-
file_paths = POOCH.fetch(
56-
f"{file_name}.zip", processor=pooch.Unzip(extract_dir=file_name)
57-
)
58-
zarr_full_path = file_paths[0].split(file_name)[0] + file_name
59-
folder = folders[ind]
60-
61-
# 2) Copy the downloaded Zarr into tests/data
62-
if os.path.isdir(str(folder)):
63-
shutil.rmtree(str(folder))
64-
shutil.copytree(Path(zarr_full_path) / file_name, folder)
65-
return [str(f) for f in folders]
43+
file_name = "AssayPlate_Greiner_CELLSTAR655090.zarr"
44+
# 2) Download/unzip a single Zarr from Zenodo
45+
file_paths = POOCH.fetch(
46+
f"{file_name}.zip", processor=pooch.Unzip(extract_dir=file_name)
47+
)
48+
zarr_full_path = file_paths[0].split(file_name)[0] + file_name
6649

50+
# 3) Copy the downloaded Zarr into tests/data
51+
if os.path.isdir(str(folder)):
52+
shutil.rmtree(str(folder))
53+
shutil.copytree(Path(zarr_full_path) / file_name, folder)
54+
return Path(folder)
6755

68-
@pytest.fixture(scope="function")
69-
def tmp_zenodo_zarr(zenodo_zarr: list[str], tmpdir: Path) -> list[str]:
70-
"""Generates a copy of the zenodo zarrs in a tmpdir"""
71-
zenodo_mip_path = str(tmpdir / Path(zenodo_zarr[1]).name)
72-
zenodo_path = str(tmpdir / Path(zenodo_zarr[0]).name)
73-
shutil.copytree(zenodo_zarr[0], zenodo_path)
74-
shutil.copytree(zenodo_zarr[1], zenodo_mip_path)
75-
return [zenodo_path, zenodo_mip_path]
1.04 MB
Binary file not shown.

tests/test_ilastik_pixel_classification_segmentation.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,52 @@
22
from pathlib import Path
33

44
import pytest
5+
import zarr
56
from devtools import debug
7+
68
from fractal_tasks_core.channels import ChannelInputModel
79

810
from ilastik_tasks.ilastik_pixel_classification_segmentation import (
911
ilastik_pixel_classification_segmentation,
1012
)
1113

12-
def test_ilastik_pixel_classification_segmentation_task_2D_single_channel(tmp_zenodo_zarr: list[str]):
14+
# TODO: add 2D testdata
15+
16+
@pytest.fixture(scope="function")
17+
def test_data_dir_3d(tmp_path: Path, zenodo_zarr_3d: list) -> str:
18+
"""
19+
Copy a test-data folder into a temporary folder.
20+
"""
21+
dest_dir = (tmp_path / "ilastik_data_3d").as_posix()
22+
debug(zenodo_zarr_3d, dest_dir)
23+
shutil.copytree(zenodo_zarr_3d, dest_dir)
24+
return dest_dir
25+
26+
27+
def test_ilastik_pixel_classification_segmentation_task_3D(test_data_dir_3d):
1328
"""
14-
Test the 2D (e.g. MIP) ilastik_pixel_classification_segmentation task with single channel input.
29+
Test the 3D ilastik_pixel_classification_segmentation task with dual channel input.
1530
"""
16-
ilastik_model = (Path(__file__).parent / "data/pixel_classifier_2D.ilp").as_posix()
17-
zarr_url = f"{tmp_zenodo_zarr[1]}/B/03/0"
31+
ilastik_model = (Path(__file__).parent / "data/pixel_classifier_3D_dual_channel.ilp").as_posix()
32+
zarr_url = f"{test_data_dir_3d}/B/03/0"
1833

1934
ilastik_pixel_classification_segmentation(
2035
zarr_url=zarr_url,
2136
level=0,
22-
channel=ChannelInputModel(label="DAPI"),
23-
channel2=None,
37+
channel=ChannelInputModel(label="DAPI_2"),
38+
channel2=ChannelInputModel(label="ECadherin_2"),
2439
ilastik_model=str(ilastik_model),
2540
output_label_name="test_label",
2641
)
27-
28-
# TODO: add (2D) dual channel ilastik model for testing
29-
# TODO: add (3D) test data for dual channel testing - code works but could not yet find suitable test data
42+
43+
# Test failing of task if model was trained with two channels
44+
# but only one is provided
45+
with pytest.raises(ValueError):
46+
ilastik_pixel_classification_segmentation(
47+
zarr_url=zarr_url,
48+
level=0,
49+
channel=ChannelInputModel(label="DAPI_2"),
50+
channel2=None,
51+
ilastik_model=str(ilastik_model),
52+
output_label_name="test_label_single_channel",
53+
)

0 commit comments

Comments
 (0)