|
14 | 14 | import numpy as np |
15 | 15 | import torch |
16 | 16 | import tqdm |
| 17 | +import zarr |
17 | 18 | from shapely.geometry import box as shapely_box |
| 19 | +from shapely.geometry import shape as feature2geometry |
18 | 20 | from shapely.strtree import STRtree |
19 | 21 | from typing_extensions import Unpack |
20 | 22 |
|
21 | 23 | from tiatoolbox import DuplicateFilter, logger |
| 24 | +from tiatoolbox.annotation.storage import Annotation |
22 | 25 | from tiatoolbox.models.engine.semantic_segmentor import ( |
23 | 26 | SemanticSegmentor, |
24 | 27 | SemanticSegmentorRunParams, |
25 | 28 | ) |
26 | 29 | from tiatoolbox.tools.patchextraction import PatchExtractor |
27 | | -from tiatoolbox.utils.misc import get_tqdm |
| 30 | +from tiatoolbox.utils.misc import get_tqdm, make_valid_poly |
| 31 | +from tiatoolbox.wsicore.wsireader import is_zarr |
28 | 32 |
|
29 | 33 | if TYPE_CHECKING: # pragma: no cover |
30 | 34 | import os |
@@ -613,17 +617,85 @@ def post_process_patches( # skipcq: PYL-R0201 |
613 | 617 | return raw_predictions |
614 | 618 |
|
615 | 619 | def save_predictions( |
616 | | - self: SemanticSegmentor, |
| 620 | + self: NucleusInstanceSegmentor, |
617 | 621 | processed_predictions: dict, |
618 | 622 | output_type: str, |
619 | 623 | save_path: Path | None = None, |
620 | 624 | **kwargs: Unpack[SemanticSegmentorRunParams], |
621 | | - ) -> dict | AnnotationStore | Path: |
| 625 | + ) -> dict | AnnotationStore | Path | list[Path]: |
622 | 626 | """Save semantic segmentation predictions to disk or return them in memory.""" |
623 | | - return super().save_predictions( |
624 | | - processed_predictions, output_type, save_path=save_path, **kwargs |
| 627 | + # Conversion to annotationstore uses a different function |
| 628 | + # for NucleusInstanceSegmentor. |
| 629 | + if output_type.lower() != "annotationstore": |
| 630 | + return super().save_predictions( |
| 631 | + processed_predictions, output_type, save_path=save_path, **kwargs |
| 632 | + ) |
| 633 | + |
| 634 | + return_probabilities = kwargs.get("return_probabilities", False) |
| 635 | + output_type_ = ( |
| 636 | + "zarr" |
| 637 | + if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities |
| 638 | + else "dict" |
| 639 | + ) |
| 640 | + |
| 641 | + # This runs dask.compute and returns numpy arrays |
| 642 | + # for saving annotationstore output. |
| 643 | + processed_predictions = super().save_predictions( |
| 644 | + processed_predictions, |
| 645 | + output_type=output_type_, |
| 646 | + save_path=save_path.with_suffix(".zarr"), |
| 647 | + **kwargs, |
625 | 648 | ) |
626 | 649 |
|
| 650 | + if isinstance(processed_predictions, Path): |
| 651 | + processed_predictions = zarr.open(str(processed_predictions), mode="r") |
| 652 | + |
| 653 | + # scale_factor set from kwargs |
| 654 | + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) |
| 655 | + # class_dict set from kwargs |
| 656 | + class_dict = kwargs.get("class_dict") |
| 657 | + |
| 658 | + # Need to add support for zarr conversion. |
| 659 | + save_paths = [] |
| 660 | + |
| 661 | + logger.info("Saving predictions as AnnotationStore.") |
| 662 | + |
| 663 | + # Not required for annotationstore |
| 664 | + processed_predictions.pop("predictions") |
| 665 | + if self.patch_mode: |
| 666 | + for i, predictions in enumerate( |
| 667 | + zip(*processed_predictions.values(), strict=False) |
| 668 | + ): |
| 669 | + predictions_ = dict( |
| 670 | + zip(processed_predictions.keys(), predictions, strict=False) |
| 671 | + ) |
| 672 | + if isinstance(self.images[i], Path): |
| 673 | + output_path = save_path.parent / (self.images[i].stem + ".db") |
| 674 | + else: |
| 675 | + output_path = save_path.parent / (str(i) + ".db") |
| 676 | + |
| 677 | + origin = predictions_.pop("coordinates")[:2] |
| 678 | + |
| 679 | + out_file = dict_to_store( |
| 680 | + processed_predictions=predictions_, |
| 681 | + class_dict=class_dict, |
| 682 | + scale_factor=scale_factor, |
| 683 | + origin=origin, |
| 684 | + ) |
| 685 | + |
| 686 | + save_paths.append(out_file) |
| 687 | + |
| 688 | + if return_probabilities: |
| 689 | + msg = ( |
| 690 | + f"Probability maps cannot be saved as AnnotationStore. " |
| 691 | + f"To visualise heatmaps in TIAToolbox Visualization tool," |
| 692 | + f"convert heatmaps in {save_path} to ome.tiff using" |
| 693 | + f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." |
| 694 | + ) |
| 695 | + logger.info(msg) |
| 696 | + |
| 697 | + return save_paths |
| 698 | + |
627 | 699 | @staticmethod |
628 | 700 | def _get_tile_info( |
629 | 701 | image_shape: list[int] | np.ndarray, |
@@ -1057,3 +1129,38 @@ def run( |
1057 | 1129 | output_type=output_type, |
1058 | 1130 | **kwargs, |
1059 | 1131 | ) |
| 1132 | + |
| 1133 | + |
def dict_to_store(
    processed_predictions: dict,
    class_dict: dict | None = None,
    origin: tuple[float, float] = (0, 0),
    scale_factor: tuple[float, float] = (1, 1),
) -> list[Annotation]:
    """Convert a dictionary of per-instance predictions to Annotations.

    Args:
        processed_predictions (dict):
            Mapping of property name to a per-instance sequence for one
            image. Must contain a "contour" entry (one contour per
            instance); may contain a "geom_type" entry (GeoJSON geometry
            type, defaults to "Polygon"). Every remaining entry is treated
            as a per-instance property sequence.
        class_dict (dict):
            Optional mapping from the raw per-instance "type" value to a
            class label.
        origin (tuple[float, float]):
            (x, y) offset applied when validating each polygon.
        scale_factor (tuple[float, float]):
            Per-axis scaling applied to contour coordinates.

    Returns:
        list[Annotation]:
            One annotation per contour, carrying the i-th value of each
            property sequence.

    """
    contours = processed_predictions.pop("contour")
    # Hoist geom_type out so it is not indexed per-instance below.
    geom_type = processed_predictions.pop("geom_type", "Polygon")

    annotations = []
    for i, contour in enumerate(contours):
        # NOTE(review): GeoJSON "Polygon" coordinates are a list of rings —
        # confirm each contour is already ring-wrapped as expected by
        # shapely's shape().
        geometry = make_valid_poly(
            feature2geometry(
                {
                    "type": geom_type,
                    "coordinates": scale_factor * np.array(contour),
                },
            ),
            origin,
        )
        # Each annotation carries the i-th element of every property
        # sequence; the raw "type" value is mapped through class_dict when
        # one is provided. (Indexing the sequence before the class_dict
        # lookup — class_dict keys are scalar type ids, not sequences.)
        properties = {
            prop: (
                class_dict[processed_predictions[prop][i]]
                if prop == "type" and class_dict is not None
                else processed_predictions[prop][i]
            )
            for prop in processed_predictions
        }
        annotations.append(Annotation(geometry, properties))

    return annotations
0 commit comments