1- from typing import Dict , List
1+ from typing import Callable , Dict , List
22
33import numpy as np
44from pathos .multiprocessing import ThreadPool as Pool
5- from skimage .filters import rank
6- from skimage .morphology import closing , disk , opening
75from skimage .util import img_as_ubyte
86from tqdm import tqdm
97
108from ..postproc import POSTPROC_LOOKUP
11- from ..utils import binarize , fill_holes_semantic , remove_debris_semantic
9+ from ..utils import (
10+ fill_holes_semantic ,
11+ majority_vote_parallel ,
12+ majority_vote_sequential ,
13+ med_filt_parallel ,
14+ med_filt_sequential ,
15+ remove_debris_semantic ,
16+ )
1217
1318__all__ = ["PostProcessor" ]
1419
2833
2934class PostProcessor :
3035 def __init__ (
31- self , instance_postproc : str , inst_key : str , aux_key : str , ** kwargs
36+ self ,
37+ instance_postproc : str ,
38+ inst_key : str ,
39+ aux_key : str ,
40+ type_post_proc : Callable = None ,
41+ sem_post_proc : Callable = None ,
42+ ** kwargs ,
3243 ) -> None :
3344 """Multi-threaded post-processing.
3445
@@ -42,6 +53,12 @@ def __init__(
4253 aux_key : Tuple[str, ...]:
4354 The key/name of the model auxilliary output that is used for the
4455 instance segmentation post-processing pipeline.
56+ type_post_proc : Callable, optional
57+ A post-processing function for the type maps. If not None, overrides
58+ the default.
59+ sem_post_proc : Callable, optional
60+ A post-processing function for the semantc seg maps. If not None,
61+ overrides the default.
4562 **kwargs
4663 Arbitrary keyword arguments that can be used for any of the private
4764 post-processing functions of this class.
@@ -57,36 +74,32 @@ def __init__(
5774 self .inst_key = inst_key
5875 self .aux_key = aux_key
5976 self .kwargs = kwargs
77+ self .sem_post_proc = sem_post_proc
78+ self .type_post_proc = type_post_proc
6079
6180 def _get_sem_map (
6281 self ,
6382 prob_map : np .ndarray ,
64- use_blur : bool = False ,
65- use_closing : bool = False ,
66- use_opening : bool = True ,
67- disk_size : int = 10 ,
83+ parallel : bool = False ,
84+ kernel_width : int = 15 ,
6885 ** kwargs ,
6986 ) -> np .ndarray :
7087 """Run the semantic segmentation post-processing."""
71- # Median filtering to get rid of noise. Adds a lot of overhead sop optional.
72- if use_blur :
73- sem = np .zeros_like (prob_map )
74- for i in range (prob_map .shape [0 ]):
75- sem [i , ...] = rank .median (
76- img_as_ubyte (prob_map [i , ...]), footprint = disk (disk_size )
77- )
78- prob_map = sem
79-
80- sem = np .argmax (prob_map , axis = 0 )
88+ sem_map = img_as_ubyte (prob_map )
8189
82- if use_opening :
83- sem = opening (sem , disk (disk_size ))
84-
85- if use_closing :
86- sem = closing (sem , disk (disk_size ))
90+ if self .sem_post_proc is not None :
91+ sem = self .sem_post_proc (sem_map )
92+ else :
93+ if parallel :
94+ sem = med_filt_parallel (
95+ sem_map , kernel_size = (kernel_width , kernel_width )
96+ )
97+ else :
98+ sem = med_filt_sequential (sem_map , kernel_width )
8799
88- sem = remove_debris_semantic (sem )
89- sem = fill_holes_semantic (sem )
100+ sem = np .argmax (sem , axis = 0 )
101+ sem = remove_debris_semantic (sem )
102+ sem = fill_holes_semantic (sem )
90103
91104 return sem
92105
@@ -105,31 +118,19 @@ def _get_type_map(
105118 self ,
106119 prob_map : np .ndarray ,
107120 inst_map : np .ndarray ,
108- use_mask : bool = False ,
121+ parallel : bool = True ,
109122 ** kwargs ,
110123 ) -> np .ndarray :
111- """Run the type map post-processing. Majority voting for each instance.
112-
113- Adapted from:
114- https://github.com/vqdang/hover_net/blob/master/models/hovernet/post_proc.py
115- """
124+ """Run the type map post-processing. Majority voting for each instance."""
116125 type_map = np .argmax (prob_map , axis = 0 )
117- if use_mask :
118- type_map = binarize (inst_map ) * type_map
119-
120- pred_id_list = np .unique (inst_map )[1 :]
121- for inst_id in pred_id_list :
122- inst_type = type_map [inst_map == inst_id ]
123- type_list , type_pixels = np .unique (inst_type , return_counts = True )
124- type_list = list (zip (type_list , type_pixels ))
125- type_list = sorted (type_list , key = lambda x : x [1 ], reverse = True )
126- cell_type = type_list [0 ][0 ]
127-
128- if cell_type == 0 :
129- if len (type_list ) > 1 :
130- cell_type = type_list [1 ][0 ]
131126
132- type_map [inst_map == inst_id ] = cell_type
127+ if self .type_post_proc is not None :
128+ type_map = self .type_post_proc (type_map , inst_map , ** kwargs )
129+ else :
130+ if parallel :
131+ type_map = majority_vote_parallel (type_map , inst_map )
132+ else :
133+ type_map = majority_vote_sequential (type_map , inst_map )
133134
134135 return type_map
135136
@@ -175,7 +176,6 @@ def run_parallel(
175176 progress_bar : bool, default=False
176177 If True, a tqdm progress bar is shown.
177178
178-
179179 Returns
180180 -------
181181 List[Dict[str, np.ndarray]]:
0 commit comments