Skip to content

Commit e6dc905

Browse files
committed
🧪 Add failing test for annotationstore conversion
1 parent caece3f commit e6dc905

File tree

2 files changed

+115
-8
lines changed

2 files changed

+115
-8
lines changed

tiatoolbox/models/engine/nucleus_instance_segmentor.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@
1414
import numpy as np
1515
import torch
1616
import tqdm
17+
import zarr
1718
from shapely.geometry import box as shapely_box
19+
from shapely.geometry import shape as feature2geometry
1820
from shapely.strtree import STRtree
1921
from typing_extensions import Unpack
2022

2123
from tiatoolbox import DuplicateFilter, logger
24+
from tiatoolbox.annotation.storage import Annotation
2225
from tiatoolbox.models.engine.semantic_segmentor import (
2326
SemanticSegmentor,
2427
SemanticSegmentorRunParams,
2528
)
2629
from tiatoolbox.tools.patchextraction import PatchExtractor
27-
from tiatoolbox.utils.misc import get_tqdm
30+
from tiatoolbox.utils.misc import get_tqdm, make_valid_poly
31+
from tiatoolbox.wsicore.wsireader import is_zarr
2832

2933
if TYPE_CHECKING: # pragma: no cover
3034
import os
@@ -613,17 +617,85 @@ def post_process_patches( # skipcq: PYL-R0201
613617
return raw_predictions
614618

615619
def save_predictions(
616-
self: SemanticSegmentor,
620+
self: NucleusInstanceSegmentor,
617621
processed_predictions: dict,
618622
output_type: str,
619623
save_path: Path | None = None,
620624
**kwargs: Unpack[SemanticSegmentorRunParams],
621-
) -> dict | AnnotationStore | Path:
625+
) -> dict | AnnotationStore | Path | list[Path]:
622626
"""Save semantic segmentation predictions to disk or return them in memory."""
623-
return super().save_predictions(
624-
processed_predictions, output_type, save_path=save_path, **kwargs
627+
# Conversion to annotationstore uses a different function
628+
# for NucleusInstanceSegmentor.
629+
if output_type.lower() != "annotationstore":
630+
return super().save_predictions(
631+
processed_predictions, output_type, save_path=save_path, **kwargs
632+
)
633+
634+
return_probabilities = kwargs.get("return_probabilities", False)
635+
output_type_ = (
636+
"zarr"
637+
if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities
638+
else "dict"
639+
)
640+
641+
# This runs dask.compute and returns numpy arrays
642+
# for saving annotationstore output.
643+
processed_predictions = super().save_predictions(
644+
processed_predictions,
645+
output_type=output_type_,
646+
save_path=save_path.with_suffix(".zarr"),
647+
**kwargs,
625648
)
626649

650+
if isinstance(processed_predictions, Path):
651+
processed_predictions = zarr.open(str(processed_predictions), mode="r")
652+
653+
# scale_factor set from kwargs
654+
scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
655+
# class_dict set from kwargs
656+
class_dict = kwargs.get("class_dict")
657+
658+
# Need to add support for zarr conversion.
659+
save_paths = []
660+
661+
logger.info("Saving predictions as AnnotationStore.")
662+
663+
# Not required for annotationstore
664+
processed_predictions.pop("predictions")
665+
if self.patch_mode:
666+
for i, predictions in enumerate(
667+
zip(*processed_predictions.values(), strict=False)
668+
):
669+
predictions_ = dict(
670+
zip(processed_predictions.keys(), predictions, strict=False)
671+
)
672+
if isinstance(self.images[i], Path):
673+
output_path = save_path.parent / (self.images[i].stem + ".db")
674+
else:
675+
output_path = save_path.parent / (str(i) + ".db")
676+
677+
origin = predictions_.pop("coordinates")[:2]
678+
679+
out_file = dict_to_store(
680+
processed_predictions=predictions_,
681+
class_dict=class_dict,
682+
scale_factor=scale_factor,
683+
origin=origin,
684+
)
685+
686+
save_paths.append(out_file)
687+
688+
if return_probabilities:
689+
msg = (
690+
f"Probability maps cannot be saved as AnnotationStore. "
691+
f"To visualise heatmaps in TIAToolbox Visualization tool,"
692+
f"convert heatmaps in {save_path} to ome.tiff using"
693+
f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff."
694+
)
695+
logger.info(msg)
696+
697+
return save_paths
698+
627699
@staticmethod
628700
def _get_tile_info(
629701
image_shape: list[int] | np.ndarray,
@@ -1057,3 +1129,38 @@ def run(
10571129
output_type=output_type,
10581130
**kwargs,
10591131
)
1132+
1133+
1134+
def dict_to_store(
1135+
processed_predictions: dict,
1136+
class_dict: dict | None = None,
1137+
origin: tuple[float, float] = (0, 0),
1138+
scale_factor: tuple[float, float] = (1, 1),
1139+
) -> list[Annotation]:
1140+
"""Helper function to convert dict to store."""
1141+
contour = processed_predictions.pop("contour")
1142+
1143+
ann = []
1144+
for i, contour_ in enumerate(contour):
1145+
ann_ = Annotation(
1146+
make_valid_poly(
1147+
feature2geometry(
1148+
{
1149+
"type": processed_predictions.get("geom_type", "Polygon"),
1150+
"coordinates": scale_factor * np.array(contour_),
1151+
},
1152+
),
1153+
origin,
1154+
),
1155+
{
1156+
prop: (
1157+
class_dict[processed_predictions[prop]][i]
1158+
if prop == "type" and class_dict is not None
1159+
else processed_predictions[prop]
1160+
)
1161+
for prop in processed_predictions
1162+
},
1163+
)
1164+
ann.append(ann_)
1165+
1166+
return ann

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def save_predictions(
558558
output_type: str,
559559
save_path: Path | None = None,
560560
**kwargs: Unpack[SemanticSegmentorRunParams],
561-
) -> dict | AnnotationStore | Path:
561+
) -> dict | AnnotationStore | Path | list[Path]:
562562
"""Save semantic segmentation predictions to disk or return them in memory.
563563
564564
This method saves predictions in one of the supported formats:
@@ -583,11 +583,11 @@ def save_predictions(
583583
- return_probabilities (bool): Whether to save probability maps.
584584
585585
Returns:
586-
dict | AnnotationStore | Path:
586+
dict | AnnotationStore | Path | list[Path]:
587587
- If output_type is "dict": returns predictions as a dictionary.
588588
- If output_type is "zarr": returns path to saved Zarr file.
589589
- If output_type is "annotationstore": returns AnnotationStore
590-
or path to .db file.
590+
or path or list of paths to .db file.
591591
592592
"""
593593
# Conversion to annotationstore uses a different function for SemanticSegmentor

0 commit comments

Comments
 (0)