Commit ff2e1bb

✨ Add support for annotationstore
1 parent a51bab2 commit ff2e1bb

File tree

4 files changed: +104 −25

tests/engines/test_semantic_segmentor.py
tiatoolbox/models/engine/patch_predictor.py
tiatoolbox/models/engine/semantic_segmentor.py
tiatoolbox/utils/misc.py


tests/engines/test_semantic_segmentor.py

Lines changed: 44 additions & 4 deletions
@@ -5,7 +5,7 @@
 import json
 import sqlite3
 from pathlib import Path
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import numpy as np
 import torch
@@ -16,6 +16,9 @@
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import imread
 
+if TYPE_CHECKING:
+    import pytest
+
 device = "cuda" if toolbox_env.has_gpu() else "cpu"
 
 
@@ -160,7 +163,9 @@ def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None:
     _test_store_output_patch(output[0])
 
 
-def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path) -> None:
+def test_save_annotation_store_nparray(
+    remote_sample: Callable, tmp_path: Path, caplog: pytest.LogCaptureFixture
+) -> None:
     """Test for saving output as annotation store using a numpy array."""
     segmentor = SemanticSegmentor(
         model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
@@ -184,7 +189,12 @@ def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path)
     assert output[0] == tmp_path / "output1" / "0.db"
     assert output[1] == tmp_path / "output1" / "1.db"
 
-    assert (tmp_path / "output1.zarr").exists()
+    assert (tmp_path / "output1" / "output.zarr").exists()
+
+    zarr_group = zarr.open(str(tmp_path / "output1" / "output.zarr"), mode="r")
+    assert "probabilities" in zarr_group
+
+    assert "Probability maps cannot be saved as AnnotationStore." in caplog.text
 
     _test_store_output_patch(output[0])
     _test_store_output_patch(output[1])
@@ -201,6 +211,7 @@ def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path)
 
     assert output[0] == tmp_path / "output2" / "0.db"
     assert output[1] == tmp_path / "output2" / "1.db"
+    assert not (tmp_path / "output2" / "output.zarr").exists()
 
     assert len(output) == 2
 
@@ -294,7 +305,9 @@ def test_wsi_segmentor_zarr(
     assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52
 
 
-def test_wsi_segmentor_annotationstore(sample_svs: Path, tmp_path: Path) -> None:
+def test_wsi_segmentor_annotationstore(
+    sample_svs: Path, tmp_path: Path, caplog: pytest.LogCaptureFixture
+) -> None:
     """Test SemanticSegmentor for WSIs with AnnotationStore output."""
     segmentor = SemanticSegmentor(
         model="fcn-tissue_mask",
@@ -314,3 +327,30 @@ def test_wsi_segmentor_annotationstore(sample_svs: Path, tmp_path: Path) -> None
     )
 
     assert output[sample_svs] == tmp_path / "wsi_out_check" / (sample_svs.stem + ".db")
+
+    # Return probabilities alongside the AnnotationStore output.
+    segmentor = SemanticSegmentor(
+        model="fcn-tissue_mask",
+        batch_size=32,
+        verbose=False,
+    )
+    # return_probabilities is True for this run.
+    output = segmentor.run(
+        images=[sample_svs],
+        return_probabilities=True,
+        return_labels=False,
+        device=device,
+        patch_mode=False,
+        save_dir=tmp_path / "wsi_prob_out_check",
+        verbose=True,
+        output_type="annotationstore",
+    )
+
+    assert output[sample_svs] == tmp_path / "wsi_prob_out_check" / (
+        sample_svs.stem + ".db"
+    )
+    assert output[sample_svs].with_suffix(".zarr").exists()
+
+    zarr_group = zarr.open(output[sample_svs].with_suffix(".zarr"), mode="r")
+    assert "probabilities" in zarr_group
+    assert "Probability maps cannot be saved as AnnotationStore." in caplog.text

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 4 additions & 1 deletion
@@ -24,6 +24,8 @@ class to support patch-based and whole slide image (WSI) inference using deep le
 
 from typing_extensions import Unpack
 
+from tiatoolbox.utils.misc import cast_to_min_dtype
+
 from .engine_abc import EngineABC, EngineABCRunParams
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -348,7 +350,8 @@ def post_process_patches(
         _ = kwargs.get("return_probabilities")
         _ = prediction_shape
        _ = prediction_dtype
-        return self.model.postproc_func(raw_predictions)
+        raw_predictions = self.model.postproc_func(raw_predictions)
+        return cast_to_min_dtype(raw_predictions)
 
     def post_process_wsi(
         self: PatchPredictor,
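
Note on the change above: post_process_patches now passes the post-processed predictions through cast_to_min_dtype (added to tiatoolbox/utils/misc.py in this commit) before returning them, so patch outputs are stored at the smallest dtype that can hold their values. A minimal sketch of the effect; the sample arrays are illustrative, not from the toolbox:

    import numpy as np
    from tiatoolbox.utils.misc import cast_to_min_dtype

    # Class indices from argmax-style post-processing are often int64 by default.
    preds = np.array([0, 1, 2, 3], dtype=np.int64)
    print(cast_to_min_dtype(preds).dtype)  # uint8: the max value 3 fits in 8 bits

    # A binary mask collapses further, to bool.
    mask = np.array([0, 1, 1, 0], dtype=np.int64)
    print(cast_to_min_dtype(mask).dtype)  # bool: the max value is 1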

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 16 additions & 13 deletions
@@ -66,9 +66,9 @@
 from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset
 from tiatoolbox.utils.misc import (
     dict_to_store_semantic_segmentor,
-    dict_to_zarr,
     get_tqdm,
 )
+from tiatoolbox.wsicore.wsireader import is_zarr
 
 from .patch_predictor import PatchPredictor, PredictorRunParams
 
@@ -599,12 +599,22 @@ def save_predictions(
                 processed_predictions, output_type, save_path=save_path, **kwargs
             )
 
-        logger.info("Saving predictions as AnnotationStore.")
+        return_probabilities = kwargs.get("return_probabilities", False)
+        output_type_ = (
+            "zarr"
+            if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities
+            else "dict"
+        )
+
         processed_predictions = super().save_predictions(
-            processed_predictions, output_type="dict", **kwargs
+            processed_predictions,
+            output_type=output_type_,
+            save_path=save_path.with_suffix(".zarr"),
+            **kwargs,
         )
 
-        return_probabilities = kwargs.get("return_probabilities", False)
+        if isinstance(processed_predictions, Path):
+            processed_predictions = zarr.open(str(processed_predictions), mode="r")
 
         # scale_factor set from kwargs
         scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
@@ -614,6 +624,7 @@
         # Need to add support for zarr conversion.
         save_paths = []
 
+        logger.info("Saving predictions as AnnotationStore.")
         if self.patch_mode:
             for i, predictions in enumerate(processed_predictions["predictions"]):
                 if isinstance(self.images[i], Path):
@@ -639,21 +650,13 @@
             save_paths = out_file
 
         if return_probabilities:
-            zarr_save_path = save_path.parent.with_suffix(".zarr")
             msg = (
                 f"Probability maps cannot be saved as AnnotationStore. "
                 f"To visualise heatmaps in TIAToolbox Visualization tool,"
-                f"convert heatmaps in {zarr_save_path} to ome.tiff using"
+                f"convert heatmaps in {save_path} to ome.tiff using "
                 f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff."
             )
             logger.info(msg)
-            processed_predictions = {
-                "predictions": processed_predictions.get("predictions"),
-            }
-            dict_to_zarr(
-                raw_predictions=processed_predictions,
-                save_path=zarr_save_path,
-            )
 
         return save_paths
 
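
Taken together with the tests above, the caller-visible behaviour is: with output_type="annotationstore", contours are written to a .db SQLite AnnotationStore, and when return_probabilities=True the probability maps are kept in a sibling .zarr next to it instead of being converted. A sketch mirroring test_wsi_segmentor_annotationstore; the paths, wsi_path variable, and import location are illustrative assumptions:

    from pathlib import Path

    from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor

    wsi_path = Path("sample.svs")  # illustrative input slide
    segmentor = SemanticSegmentor(model="fcn-tissue_mask", batch_size=32)
    output = segmentor.run(
        images=[wsi_path],
        patch_mode=False,
        return_probabilities=True,
        save_dir=Path("out_dir"),  # illustrative save directory
        output_type="annotationstore",
    )
    # Contours land in out_dir/sample.db; probability maps stay in
    # out_dir/sample.zarr under the "probabilities" key, and the engine logs
    # "Probability maps cannot be saved as AnnotationStore."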

tiatoolbox/utils/misc.py

Lines changed: 40 additions & 7 deletions
@@ -11,6 +11,7 @@
 from typing import IO, TYPE_CHECKING
 
 import cv2
+import dask.array as da
 import joblib
 import numcodecs
 import numpy as np
@@ -1369,12 +1370,11 @@ def dict_to_store_semantic_segmentor(
             for each patch.
 
     """
-    preds = patch_output["predictions"]
+    preds = da.from_array(patch_output["predictions"], chunks="auto")
 
     # Get the number of unique predictions
-    layer_list = np.unique(preds)
-
-    layer_list = np.delete(layer_list, np.where(layer_list == 0))
+    layer_list = da.unique(preds).compute()
+    layer_list = layer_list[layer_list != 0]
 
     store = SQLiteStore()
 
@@ -1383,13 +1383,12 @@ def dict_to_store_semantic_segmentor(
     annotations_list: list[Annotation] = []
 
     for type_class in layer_list:
-        layer = np.where(preds[:] == type_class, 1, 0)
+        layer = da.where(preds == type_class, 1, 0).astype("uint8").compute()
         contours, hierarchy = cv2.findContours(
-            layer.astype("uint8"),
+            layer,
             cv2.RETR_CCOMP,
             cv2.CHAIN_APPROX_NONE,
         )
-
         annotations_list_ = process_contours(contours, hierarchy, scale_factor)
         annotations_list.extend(annotations_list_)
 
@@ -1815,3 +1814,37 @@ def get_tqdm() -> type[tqdm_notebook | tqdm]:
     if is_notebook():  # pragma: no cover
         return tqdm_notebook.tqdm
     return tqdm
+
+
+def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array:
+    """Cast the input array to the minimal data type required to represent its values.
+
+    This function determines the maximum value in the array and casts it to the smallest
+    unsigned integer type (or boolean) that can accommodate all values. It supports both
+    NumPy and Dask arrays and preserves the input type in the output.
+
+    For Dask arrays, the maximum value is computed lazily and only when needed.
+
+    Args:
+        array (Union[np.ndarray, da.Array]): Input array containing integer values.
+
+    Returns:
+        (np.ndarray or da.Array):
+            A copy of the input array cast to the minimal required dtype.
+            - If the maximum value is 1, the array is cast to boolean.
+            - Otherwise, it is cast to the smallest suitable unsigned integer type.
+
+    """
+    is_dask = isinstance(array, da.Array)
+    max_value = da.max(array) if is_dask else np.max(array)
+    max_value = max_value.compute() if is_dask else max_value
+
+    if max_value == 1:
+        return array.astype(bool)
+
+    dtypes = [np.uint8, np.uint16, np.uint32, np.uint64]
+    for dtype in dtypes:
+        if max_value <= np.iinfo(dtype).max:
+            return array.astype(dtype)
+
+    return array
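
Because the helper checks isinstance(array, da.Array), Dask inputs stay lazy apart from the scalar maximum, which is the one value that must be materialised to pick a dtype. A small sketch of that behaviour; the array contents are illustrative:

    import dask.array as da
    import numpy as np

    from tiatoolbox.utils.misc import cast_to_min_dtype

    arr = da.from_array(np.arange(300, dtype=np.int64), chunks=100)
    out = cast_to_min_dtype(arr)  # max is 299: too big for uint8, fits uint16
    print(type(out).__name__, out.dtype)  # Array uint16; the data itself is still lazy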
