44
55import uuid
66from collections import deque
7+ from pathlib import Path
78from typing import TYPE_CHECKING
89
9- import dask
10+ import dask .array as da
11+ import dask .dataframe as dd
12+
1013# replace with the sql database once the PR in place
1114import joblib
1215import numpy as np
16+ import pandas as pd
1317import torch
1418import tqdm
15- import dask .array as da
1619from shapely .geometry import box as shapely_box
1720from shapely .strtree import STRtree
18- from torch .utils .data import DataLoader
1921from typing_extensions import Unpack
2022
23+ from tiatoolbox import DuplicateFilter , logger
2124from tiatoolbox .models .engine .semantic_segmentor import (
2225 SemanticSegmentor ,
2326 SemanticSegmentorRunParams ,
2427)
2528from tiatoolbox .tools .patchextraction import PatchExtractor
26- from tiatoolbox .models .models_abc import ModelABC
2729from tiatoolbox .utils .misc import get_tqdm
28- from .engine_abc import EngineABCRunParams
29- from tiatoolbox import DuplicateFilter , logger
30- from pathlib import Path
31-
3230
3331if TYPE_CHECKING : # pragma: no cover
3432 import os
3533 from collections .abc import Callable
3634
35+ from torch .utils .data import DataLoader
3736
3837 from tiatoolbox .annotation import AnnotationStore
38+ from tiatoolbox .models .models_abc import ModelABC
3939 from tiatoolbox .wsicore import WSIReader
4040
41+ from .engine_abc import EngineABCRunParams
4142 from .io_config import IOInstanceSegmentorConfig , IOSegmentorConfig
4243
4344
@@ -490,7 +491,9 @@ def infer_patches(
490491 labels .append (da .from_array (np .array (batch_data ["label" ])))
491492
492493 for i in range (num_expected_output ):
493- raw_predictions ["probabilities" ][i ] = da .concatenate (probabilities [i ], axis = 0 )
494+ raw_predictions ["probabilities" ][i ] = da .concatenate (
495+ probabilities [i ], axis = 0
496+ )
494497
495498 if return_coordinates :
496499 raw_predictions ["coordinates" ] = da .concatenate (coordinates , axis = 0 )
@@ -548,8 +551,8 @@ def _run_patch_mode(
548551 return_coordinates = output_type == "annotationstore" ,
549552 )
550553
551- raw_predictions [ "predictions" ] = self .post_process_patches (
552- raw_predictions = raw_predictions [ "probabilities" ] ,
554+ raw_predictions = self .post_process_patches (
555+ raw_predictions = raw_predictions ,
553556 prediction_shape = None ,
554557 prediction_dtype = None ,
555558 ** kwargs ,
@@ -570,11 +573,11 @@ def _run_patch_mode(
570573
571574 def post_process_patches ( # skipcq: PYL-R0201
572575 self : NucleusInstanceSegmentor ,
573- raw_predictions : da . Array ,
576+ raw_predictions : dict ,
574577 prediction_shape : tuple [int , ...], # noqa: ARG002
575578 prediction_dtype : type , # noqa: ARG002
576579 ** kwargs : Unpack [EngineABCRunParams ], # noqa: ARG002
577- ) -> dask . array . Array :
580+ ) -> dict :
578581 """Post-process raw patch predictions from inference.
579582
580583 This method applies a post-processing function (e.g., smoothing, filtering)
@@ -596,9 +599,44 @@ def post_process_patches( # skipcq: PYL-R0201
596599 Post-processed predictions as a Dask array.
597600
598601 """
599- raw_predictions = self .model .postproc_func (raw_predictions )
602+ probabilities = raw_predictions ["probabilities" ]
603+ predictions = [[] for _ in range (probabilities [0 ].shape [0 ])]
604+ inst_dict = [[] for _ in range (probabilities [0 ].shape [0 ])]
605+ for idx in range (probabilities [0 ].shape [0 ]):
606+ predictions [idx ], inst_dict [idx ] = self .model .postproc_func (
607+ [probabilities [0 ][idx ], probabilities [1 ][idx ], probabilities [2 ][idx ]]
608+ )
609+ inst_dict [idx ] = dd .from_pandas (pd .DataFrame (inst_dict [idx ]))
610+
611+ raw_predictions ["predictions" ] = da .stack (predictions , axis = 0 )
612+ raw_predictions ["inst_dict" ] = inst_dict
613+
600614 return raw_predictions
601615
616+ def save_predictions (
617+ self : SemanticSegmentor ,
618+ processed_predictions : dict ,
619+ output_type : str ,
620+ save_path : Path | None = None ,
621+ ** kwargs : Unpack [SemanticSegmentorRunParams ],
622+ ) -> dict | AnnotationStore | Path :
623+ """Save semantic segmentation predictions to disk or return them in memory."""
624+ # Conversion to annotationstore uses a different function for SemanticSegmentor
625+ inst_dict : list [dd .DataFrame ] | None = processed_predictions .pop (
626+ "inst_dict" , None
627+ )
628+ out = super ().save_predictions (
629+ processed_predictions , output_type , save_path = save_path , ** kwargs
630+ )
631+
632+ if isinstance (out , dict ):
633+ out ["inst_dict" ] = [[] for _ in range (len (inst_dict ))]
634+ for idx in range (len (inst_dict )):
635+ out ["inst_dict" ][idx ] = inst_dict [idx ].compute ()
636+ return out
637+
638+ return out
639+
602640 @staticmethod
603641 def _get_tile_info (
604642 image_shape : list [int ] | np .ndarray ,
0 commit comments