
Commit 2797ff9

✅ Test patch mode with dict output
1 parent 4bc33b7 commit 2797ff9

3 files changed: +63 additions, -47 deletions


requirements/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@ aiohttp>=3.8.1
 albumentations>=1.3.0
 bokeh>=3.1.1, <3.6.0
 Click>=8.1.3, <8.2.0
-dask>=2025.10.0
+dask[array]>=2025.10.0
+dask[dataframe]>=2025.10.0
 defusedxml>=0.7.1
 filelock>=3.9.0
 flask>=2.2.2
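Note on the dependency split: the engine changes below import both dask.array and dask.dataframe, and the dataframe layer additionally depends on pandas, which only the dask[dataframe] extra guarantees. A minimal sketch of what each extra enables (illustrative, not from the commit):

```python
# Illustrative check of what the two extras provide. `dask[array]` covers
# chunked numpy-backed arrays; `dask[dataframe]` additionally installs the
# pandas-backed dataframe layer. Without the right extra, these imports
# raise ImportError at runtime.
import dask.array as da        # needs dask[array]
import dask.dataframe as dd    # needs dask[dataframe]
import numpy as np
import pandas as pd

lazy_array = da.from_array(np.zeros((4, 4)), chunks=(2, 2))
lazy_frame = dd.from_pandas(pd.DataFrame({"x": [1, 2, 3]}), npartitions=1)
print(lazy_array.sum().compute(), lazy_frame["x"].sum().compute())
```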
Lines changed: 9 additions & 32 deletions

@@ -1,20 +1,19 @@
 """Test tiatoolbox.models.engine.nucleus_instance_segmentor."""

-import shutil
 from collections.abc import Callable
 from pathlib import Path
-from typing import Literal, Final
+from typing import Final

-import torch
 import numpy as np
+import torch

-from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor
+from tiatoolbox.models import NucleusInstanceSegmentor
 from tiatoolbox.wsicore import WSIReader

 device = "cuda:0" if torch.cuda.is_available() else "cpu"


-def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path) -> None:
+def test_functionality_patch_mode(remote_sample: Callable) -> None:
     """Patch mode functionality test for nuclei instance segmentor."""
     mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
     mini_wsi = WSIReader.open(mini_wsi_svs)
@@ -34,10 +33,7 @@ def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path)
         units=units,
     )

-    patches = np.stack(
-        arrays=[patch1, patch2],
-        axis=0
-    )
+    patches = np.stack(arrays=[patch1, patch2], axis=0)

     inst_segmentor = NucleusInstanceSegmentor(
         batch_size=1,
@@ -48,29 +44,10 @@ def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path)
         images=patches,
         patch_mode=True,
         device=device,
-        save_dir=track_tmp_path / "hovernet_fast-pannuke",
         output_type="dict",
     )

-    assert output
-
-
-def test_functionality_wsi(remote_sample: Callable, track_tmp_path: Path) -> None:
-    """Local functionality test for nuclei instance segmentor."""
-    root_save_dir = Path(track_tmp_path)
-    save_dir = Path(f"{track_tmp_path}/output")
-    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
-
-    # * generate full output w/o parallel post-processing worker first
-    shutil.rmtree(save_dir, ignore_errors=True)
-    inst_segmentor = NucleusInstanceSegmentor(
-        batch_size=8,
-        num_postproc_workers=0,
-        pretrained_model="hovernet_fast-pannuke",
-    )
-    output = inst_segmentor.run(
-        [mini_wsi_svs],
-        patch_mode=False,
-        device=device,
-        save_dir=save_dir,
-    )
+    assert np.max(output["predictions"][0][:]) == 41
+    assert np.max(output["predictions"][1][:]) == 17
+    assert len(output["inst_dict"][0].columns) == 41
+    assert len(output["inst_dict"][1].columns) == 17
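The asserts encode the expected structure of the "dict" output: output["predictions"] holds one labelled instance map per input patch (instance IDs run 1..N, so the max equals the nucleus count: 41 and 17 here), and output["inst_dict"] holds one DataFrame per patch with one column per detected instance. A hedged sketch of consuming that structure (the function and loop are illustrative, not part of the commit):

```python
# Illustrative consumer of the dict returned by NucleusInstanceSegmentor.run()
# with patch_mode=True and output_type="dict"; structure inferred from the
# asserts above.
import numpy as np


def summarize(output: dict) -> None:
    for i, inst_df in enumerate(output["inst_dict"]):
        inst_map = np.asarray(output["predictions"][i])  # 0 = background
        n_instances = int(inst_map.max())                # IDs are 1..N
        assert len(inst_df.columns) == n_instances       # one column per nucleus
        print(f"patch {i}: {n_instances} nuclei detected")
```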

tiatoolbox/models/engine/nucleus_instance_segmentor.py

Lines changed: 52 additions & 14 deletions
@@ -4,40 +4,41 @@

 import uuid
 from collections import deque
+from pathlib import Path
 from typing import TYPE_CHECKING

-import dask
+import dask.array as da
+import dask.dataframe as dd
+
 # replace with the sql database once the PR in place
 import joblib
 import numpy as np
+import pandas as pd
 import torch
 import tqdm
-import dask.array as da
 from shapely.geometry import box as shapely_box
 from shapely.strtree import STRtree
-from torch.utils.data import DataLoader
 from typing_extensions import Unpack

+from tiatoolbox import DuplicateFilter, logger
 from tiatoolbox.models.engine.semantic_segmentor import (
     SemanticSegmentor,
     SemanticSegmentorRunParams,
 )
 from tiatoolbox.tools.patchextraction import PatchExtractor
-from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.utils.misc import get_tqdm
-from .engine_abc import EngineABCRunParams
-from tiatoolbox import DuplicateFilter, logger
-from pathlib import Path
-

 if TYPE_CHECKING:  # pragma: no cover
     import os
     from collections.abc import Callable

+    from torch.utils.data import DataLoader

     from tiatoolbox.annotation import AnnotationStore
+    from tiatoolbox.models.models_abc import ModelABC
     from tiatoolbox.wsicore import WSIReader

+    from .engine_abc import EngineABCRunParams
     from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig
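The import reshuffle moves DataLoader, ModelABC, and EngineABCRunParams under the TYPE_CHECKING guard, so they are imported only for static analysis and cost nothing at runtime. A generic sketch of the idiom (module choice illustrative):

```python
# Generic sketch of the TYPE_CHECKING idiom used above. The guarded import
# runs only under a type checker; `from __future__ import annotations` keeps
# the annotation as a string, so torch is never imported at runtime here.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # pragma: no cover
    from torch.utils.data import DataLoader


def count_batches(loader: DataLoader) -> int:
    """Works at runtime even if torch were absent from the environment."""
    return sum(1 for _ in loader)
```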

@@ -490,7 +491,9 @@ def infer_patches(
             labels.append(da.from_array(np.array(batch_data["label"])))

         for i in range(num_expected_output):
-            raw_predictions["probabilities"][i] = da.concatenate(probabilities[i], axis=0)
+            raw_predictions["probabilities"][i] = da.concatenate(
+                probabilities[i], axis=0
+            )

         if return_coordinates:
             raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
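The reflowed da.concatenate call is behaviour-preserving (a line-length fix); it merges the per-batch chunks of each output head into one lazy array. A small self-contained sketch of the pattern, with invented shapes:

```python
# Sketch of the per-output accumulation in infer_patches (shapes invented).
# Each batch contributes one chunk; nothing is evaluated until .compute().
import dask.array as da
import numpy as np

batches = [da.from_array(np.random.rand(2, 164, 164, 5)) for _ in range(3)]
merged = da.concatenate(batches, axis=0)  # lazy, shape (6, 164, 164, 5)
print(merged.shape)
print(float(merged[0].mean().compute()))  # evaluation happens here
```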
@@ -548,8 +551,8 @@ def _run_patch_mode(
             return_coordinates=output_type == "annotationstore",
         )

-        raw_predictions["predictions"] = self.post_process_patches(
-            raw_predictions=raw_predictions["probabilities"],
+        raw_predictions = self.post_process_patches(
+            raw_predictions=raw_predictions,
             prediction_shape=None,
             prediction_dtype=None,
             **kwargs,
@@ -570,11 +573,11 @@ def _run_patch_mode(

     def post_process_patches(  # skipcq: PYL-R0201
         self: NucleusInstanceSegmentor,
-        raw_predictions: da.Array,
+        raw_predictions: dict,
         prediction_shape: tuple[int, ...],  # noqa: ARG002
         prediction_dtype: type,  # noqa: ARG002
         **kwargs: Unpack[EngineABCRunParams],  # noqa: ARG002
-    ) -> dask.array.Array:
+    ) -> dict:
         """Post-process raw patch predictions from inference.

         This method applies a post-processing function (e.g., smoothing, filtering)
@@ -596,9 +599,44 @@ def post_process_patches( # skipcq: PYL-R0201
             Post-processed predictions as a Dask array.

         """
-        raw_predictions = self.model.postproc_func(raw_predictions)
+        probabilities = raw_predictions["probabilities"]
+        predictions = [[] for _ in range(probabilities[0].shape[0])]
+        inst_dict = [[] for _ in range(probabilities[0].shape[0])]
+        for idx in range(probabilities[0].shape[0]):
+            predictions[idx], inst_dict[idx] = self.model.postproc_func(
+                [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]]
+            )
+            inst_dict[idx] = dd.from_pandas(pd.DataFrame(inst_dict[idx]))
+
+        raw_predictions["predictions"] = da.stack(predictions, axis=0)
+        raw_predictions["inst_dict"] = inst_dict
+
         return raw_predictions

+    def save_predictions(
+        self: SemanticSegmentor,
+        processed_predictions: dict,
+        output_type: str,
+        save_path: Path | None = None,
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> dict | AnnotationStore | Path:
+        """Save semantic segmentation predictions to disk or return them in memory."""
+        # Conversion to annotationstore uses a different function for SemanticSegmentor
+        inst_dict: list[dd.DataFrame] | None = processed_predictions.pop(
+            "inst_dict", None
+        )
+        out = super().save_predictions(
+            processed_predictions, output_type, save_path=save_path, **kwargs
+        )
+
+        if isinstance(out, dict):
+            out["inst_dict"] = [[] for _ in range(len(inst_dict))]
+            for idx in range(len(inst_dict)):
+                out["inst_dict"][idx] = inst_dict[idx].compute()
+            return out
+
+        return out
+
     @staticmethod
     def _get_tile_info(
         image_shape: list[int] | np.ndarray,
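Taken together: post_process_patches now runs the model's postproc function once per patch, stacks the eager instance maps into a Dask array, and wraps each per-patch instance table in a lazy Dask DataFrame; save_predictions then computes those tables back to pandas only on the dict output path. A hedged, self-contained sketch of that round trip (fake_postproc is an invented stand-in; the real engine calls self.model.postproc_func on three probability heads per patch):

```python
# Stand-in for the dict round trip added in this commit. `fake_postproc`
# mimics a HoVer-Net-style postproc returning an instance map plus an
# info dict keyed by instance ID.
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd


def fake_postproc(prob: np.ndarray) -> tuple[np.ndarray, dict]:
    inst_map = (prob > 0.5).astype(np.int32)            # toy labelled map
    inst_info = {1: {"centroid": [0, 0], "type": 1}}    # keyed by instance ID
    return inst_map, inst_info


probs = [np.random.rand(164, 164) for _ in range(2)]
predictions, inst_dict = [], []
for prob in probs:
    inst_map, info = fake_postproc(prob)
    predictions.append(da.from_array(inst_map))
    # pd.DataFrame(dict-of-dicts) yields one column per instance ID,
    # matching the test's len(...columns) asserts.
    inst_dict.append(dd.from_pandas(pd.DataFrame(info), npartitions=1))

stacked = da.stack(predictions, axis=0)    # as in post_process_patches
tables = [f.compute() for f in inst_dict]  # as in save_predictions
print(stacked.shape, [len(t.columns) for t in tables])
```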
