Commit da6a1ea

✅ Test patch mode with dict and zarr output

1 parent 2797ff9 · commit da6a1ea

4 files changed: +149 −31 lines

4 files changed

+149
-31
lines changed

tests/engines/test_nucleus_instance_segmentor.py (90 additions, 4 deletions)

@@ -6,14 +6,17 @@
 
 import numpy as np
 import torch
+import zarr
 
 from tiatoolbox.models import NucleusInstanceSegmentor
 from tiatoolbox.wsicore import WSIReader
 
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
-def test_functionality_patch_mode(remote_sample: Callable) -> None:
+def test_functionality_patch_mode(
+    remote_sample: Callable, track_tmp_path: Path
+) -> None:
     """Patch mode functionality test for nuclei instance segmentor."""
     mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
     mini_wsi = WSIReader.open(mini_wsi_svs)
@@ -33,7 +36,10 @@ def test_functionality_patch_mode(remote_sample: Callable) -> None:
         units=units,
     )
 
-    patches = np.stack(arrays=[patch1, patch2], axis=0)
+    # Test dummy input, should result in no output segmentation
+    patch3 = np.zeros_like(patch1)
+
+    patches = np.stack(arrays=[patch1, patch2, patch3], axis=0)
 
     inst_segmentor = NucleusInstanceSegmentor(
         batch_size=1,
@@ -49,5 +55,85 @@ def test_functionality_patch_mode(remote_sample: Callable) -> None:
 
     assert np.max(output["predictions"][0][:]) == 41
     assert np.max(output["predictions"][1][:]) == 17
-    assert len(output["inst_dict"][0].columns) == 41
-    assert len(output["inst_dict"][1].columns) == 17
+    assert np.max(output["predictions"][2][:]) == 0
+
+    assert len(output["box"][0]) == 41
+    assert len(output["box"][1]) == 17
+    assert len(output["box"][2]) == 0
+
+    assert len(output["centroid"][0]) == 41
+    assert len(output["centroid"][1]) == 17
+    assert len(output["centroid"][2]) == 0
+
+    assert len(output["contour"][0]) == 41
+    assert len(output["contour"][1]) == 17
+    assert len(output["contour"][2]) == 0
+
+    assert len(output["prob"][0]) == 41
+    assert len(output["prob"][1]) == 17
+    assert len(output["prob"][2]) == 0
+
+    assert len(output["type"][0]) == 41
+    assert len(output["type"][1]) == 17
+    assert len(output["type"][2]) == 0
+
+    output_ = output
+
+    output = inst_segmentor.run(
+        images=patches,
+        patch_mode=True,
+        device=device,
+        output_type="zarr",
+        save_dir=track_tmp_path / "patch_output_zarr",
+    )
+
+    output = zarr.open(output, mode="r")
+
+    assert np.max(output["predictions"][0][:]) == 41
+    assert np.max(output["predictions"][1][:]) == 17
+
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["box"][0], output_["box"][0], strict=False)
+    )
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["box"][1], output_["box"][1], strict=False)
+    )
+    assert len(output["box"][2]) == 0
+
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["centroid"][0], output_["centroid"][0], strict=False)
+    )
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["centroid"][1], output_["centroid"][1], strict=False)
+    )
+
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["contour"][0], output_["contour"][0], strict=False)
+    )
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["contour"][1], output_["contour"][1], strict=False)
+    )
+
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["prob"][0], output_["prob"][0], strict=False)
+    )
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["prob"][1], output_["prob"][1], strict=False)
+    )
+
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["type"][0], output_["type"][0], strict=False)
+    )
+    assert all(
+        np.array_equal(a, b)
+        for a, b in zip(output["type"][1], output_["type"][1], strict=False)
+    )
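
Note on the zarr output exercised above: run(..., output_type="zarr") returns a path that zarr.open can read back, with one "predictions" array plus per-patch components such as "box/0". A minimal inspection sketch, where the on-disk file name under patch_output_zarr is a placeholder and not the engine's actual naming:

import zarr

# Open the group written by NucleusInstanceSegmentor.run(..., output_type="zarr").
# The exact file name under save_dir is assumed here for illustration.
group = zarr.open("patch_output_zarr/output.zarr", mode="r")

pred_patch0 = group["predictions"][0][:]    # instance label map for patch 0
num_nuclei_patch0 = len(group["box/0"])     # one bounding box per detected nucleus
centroids_patch1 = group["centroid/1"][:]   # matches the in-memory dict output above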

tiatoolbox/models/architecture/hovernet.py (29 additions, 2 deletions)

@@ -4,9 +4,13 @@
 
 import math
 from collections import OrderedDict
-import dask
+
 import cv2
+import dask
+import dask.array as da
+import dask.dataframe as dd
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn.functional as F  # noqa: N812
 from scipy import ndimage
@@ -22,6 +26,8 @@
 from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.utils.misc import get_bounding_box
 
+dask.config.set({"dataframe.convert-string": False})
+
 
 class TFSamepaddingLayer(nn.Module):
     """To align with tensorflow `same` padding.
@@ -782,7 +788,28 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
         pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
         nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)
 
-        return pred_inst, nuc_inst_info_dict
+        if not nuc_inst_info_dict:
+            nuc_inst_info_dict = {  # inst_id should start at 1
+                "box": da.empty(shape=0),
+                "centroid": da.empty(shape=0),
+                "contour": da.empty(shape=0),
+                "prob": da.empty(shape=0),
+                "type": da.empty(shape=0),
+            }
+            return pred_inst, nuc_inst_info_dict
+
+        # dask dataframe does not support transpose
+        nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()
+
+        # create dask dataframe
+        nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df)
+
+        # reinitialize nuc_inst_info_dict
+        nuc_inst_info_dict_ = {}
+        for key in nuc_inst_info_df.columns:
+            nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True)
+
+        return pred_inst, nuc_inst_info_dict_
 
     @staticmethod
     def infer_batch(  # skipcq: PYL-W0221
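
The new tail of postproc round-trips the per-nucleus info dict through pandas and dask so that each attribute becomes a lazily evaluated dask array. A toy sketch of that conversion outside the model code, with made-up values; only the pandas/dask calls mirror the diff:

import dask.dataframe as dd
import pandas as pd

# Stand-in for HoVerNet.get_instance_info output: {inst_id: {attribute: value}}.
nuc_inst_info_dict = {
    1: {"box": [0, 0, 5, 5], "centroid": [2.5, 2.5], "prob": 0.9, "type": 1},
    2: {"box": [10, 10, 20, 20], "centroid": [15.0, 15.0], "prob": 0.8, "type": 2},
}

# Rows become instances, columns become attributes ("box", "centroid", ...).
df = pd.DataFrame(nuc_inst_info_dict).transpose()

# One dask array per attribute; lengths=True makes the chunk sizes concrete.
ddf = dd.from_pandas(df, npartitions=1)
per_key_arrays = {key: ddf[key].to_dask_array(lengths=True) for key in df.columns}

print(per_key_arrays["prob"].compute())  # object array holding [0.9, 0.8]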

tiatoolbox/models/engine/engine_abc.py (26 additions, 7 deletions)

@@ -46,6 +46,7 @@
 import zarr
 from dask import compute
 from dask.diagnostics import ProgressBar
+from numcodecs import Pickle
 from torch import nn
 from typing_extensions import Unpack
 
@@ -71,6 +72,8 @@
 from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.type_hints import IntPair, Resolution, Units
 
+dask.config.set({"dataframe.convert-string": False})
+
 
 class EngineABCRunParams(TypedDict, total=False):
     """Parameters for configuring the :func:`EngineABC.run()` method.
@@ -645,13 +648,29 @@ def save_predictions(
         keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
         write_tasks = []
         for key in keys_to_compute:
-            dask_array = processed_predictions[key].rechunk("auto")
-            task = dask_array.to_zarr(
-                url=save_path,
-                component=key,
-                compute=False,
-            )
-            write_tasks.append(task)
+            dask_output = processed_predictions[key]
+            if isinstance(dask_output, da.Array):
+                dask_output = dask_output.rechunk("auto")
+                task = dask_output.to_zarr(
+                    url=save_path, component=key, compute=False, object_codec=None
+                )
+                write_tasks.append(task)
+
+            if isinstance(dask_output, list) and all(
+                isinstance(dask_array, da.Array) for dask_array in dask_output
+            ):
+                for i, dask_array in enumerate(dask_output):
+                    object_codec = (
+                        Pickle() if dask_array.dtype == "object" else None
+                    )
+                    task = dask_array.to_zarr(
+                        url=save_path,
+                        component=f"{key}/{i}",
+                        compute=False,
+                        object_codec=object_codec,
+                    )
+                    write_tasks.append(task)
+
         msg = f"Saving output to {save_path}."
         logger.info(msg=msg)
         with ProgressBar():
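
save_predictions now builds deferred zarr writes (compute=False) and triggers them together, using a Pickle codec for object-dtype arrays. A self-contained sketch of the same pattern on dummy data; the path and array contents are illustrative, and it assumes, as the diff does, that dask's to_zarr forwards object_codec through to zarr array creation:

import dask.array as da
import numpy as np
from dask import compute
from dask.diagnostics import ProgressBar
from numcodecs import Pickle

save_path = "demo_output.zarr"  # placeholder path

# A numeric component and an object-dtype component (the latter needs a pickle codec).
predictions = da.from_array(np.arange(12).reshape(3, 4), chunks=(1, 4))
boxes = da.from_array(np.array([[0, 0, 5, 5], [1, 1, 7, 9]], dtype=object), chunks=1)

# Build the writes lazily, then run them together under a progress bar,
# mirroring the compute=False / ProgressBar flow in save_predictions.
tasks = [
    predictions.to_zarr(url=save_path, component="predictions", compute=False),
    boxes.to_zarr(url=save_path, component="box/0", compute=False, object_codec=Pickle()),
]
with ProgressBar():
    compute(*tasks)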

tiatoolbox/models/engine/nucleus_instance_segmentor.py (4 additions, 18 deletions)

@@ -8,12 +8,10 @@
 from typing import TYPE_CHECKING
 
 import dask.array as da
-import dask.dataframe as dd
 
 # replace with the sql database once the PR in place
 import joblib
 import numpy as np
-import pandas as pd
 import torch
 import tqdm
 from shapely.geometry import box as shapely_box
@@ -601,15 +599,15 @@ def post_process_patches(  # skipcq: PYL-R0201
         """
         probabilities = raw_predictions["probabilities"]
         predictions = [[] for _ in range(probabilities[0].shape[0])]
-        inst_dict = [[] for _ in range(probabilities[0].shape[0])]
+        inst_dict = [[{}] for _ in range(probabilities[0].shape[0])]
         for idx in range(probabilities[0].shape[0]):
            predictions[idx], inst_dict[idx] = self.model.postproc_func(
                [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]]
            )
-            inst_dict[idx] = dd.from_pandas(pd.DataFrame(inst_dict[idx]))
 
         raw_predictions["predictions"] = da.stack(predictions, axis=0)
-        raw_predictions["inst_dict"] = inst_dict
+        for key in inst_dict[0]:
+            raw_predictions[key] = [d[key] for d in inst_dict]
 
         return raw_predictions
 
@@ -621,22 +619,10 @@ def save_predictions(
         **kwargs: Unpack[SemanticSegmentorRunParams],
     ) -> dict | AnnotationStore | Path:
         """Save semantic segmentation predictions to disk or return them in memory."""
-        # Conversion to annotationstore uses a different function for SemanticSegmentor
-        inst_dict: list[dd.DataFrame] | None = processed_predictions.pop(
-            "inst_dict", None
-        )
-        out = super().save_predictions(
+        return super().save_predictions(
             processed_predictions, output_type, save_path=save_path, **kwargs
         )
 
-        if isinstance(out, dict):
-            out["inst_dict"] = [[] for _ in range(len(inst_dict))]
-            for idx in range(len(inst_dict)):
-                out["inst_dict"][idx] = inst_dict[idx].compute()
-            return out
-
-        return out
-
     @staticmethod
     def _get_tile_info(
         image_shape: list[int] | np.ndarray,
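
For clarity, the reworked post_process_patches loop is just a list-of-dicts to dict-of-lists transposition, so downstream code can index results first by key and then by patch. A tiny standalone illustration with made-up values:

# One dict per patch in, one list per key out - the shape save_predictions expects.
inst_dict = [
    {"box": ["b0_a", "b0_b"], "centroid": ["c0_a", "c0_b"]},  # patch 0, two nuclei
    {"box": ["b1_a"], "centroid": ["c1_a"]},                  # patch 1, one nucleus
    {"box": [], "centroid": []},                              # patch 2, dummy patch
]

raw_predictions = {}
for key in inst_dict[0]:
    raw_predictions[key] = [d[key] for d in inst_dict]

assert len(raw_predictions["box"][0]) == 2
assert len(raw_predictions["box"][2]) == 0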
