1- """Imported most of the stuff from stardist repo. Minor modifications.
1+ """Copied the polygons to label utilities from stardist repo (with minor modifications) .
22
33BSD 3-Clause License
44
3434from typing import Tuple
3535
3636import numpy as np
37- import scipy .ndimage as ndi
38- from skimage import img_as_ubyte
3937from skimage .draw import polygon
40- from skimage .measure import regionprops
38+ from skimage .morphology import disk , erosion
4139
42- from ...utils import bounding_box , remap_label , remove_small_objects
43- from .drfns import find_maxima , h_minima_reconstruction
40+ from .nms import get_bboxes , nms_stardist
4441
45- __all__ = ["post_proc_stardist" , " post_proc_stardist_orig" , "polygons_to_label" ]
42+ __all__ = ["post_proc_stardist_orig" , "polygons_to_label" ]
4643
4744
4845def polygons_to_label_coord (
@@ -191,42 +188,25 @@ def polygons_to_label(
191188 return polygons_to_label_coord (coord , shape = shape , labels = ind )
192189
193190
194- def _clean_up (inst_map : np .ndarray , size : int = 150 , ** kwargs ) -> np .ndarray :
195- """Clean up overlapping instances."""
196- mask = remap_label (inst_map .copy ())
197- mask_connected = ndi .label (mask )[0 ]
198-
199- labels_connected = np .unique (mask_connected )[1 :]
200- for lab in labels_connected :
201- inst = np .array (mask_connected == lab , copy = True )
202- y1 , y2 , x1 , x2 = bounding_box (inst )
203- y1 = y1 - 2 if y1 - 2 >= 0 else y1
204- x1 = x1 - 2 if x1 - 2 >= 0 else x1
205- x2 = x2 + 2 if x2 + 2 <= mask_connected .shape [1 ] - 1 else x2
206- y2 = y2 + 2 if y2 + 2 <= mask_connected .shape [0 ] - 1 else y2
207-
208- box_insts = mask [y1 :y2 , x1 :x2 ]
209- if len (np .unique (ndi .label (box_insts )[0 ])) <= 2 :
210- real_labels , counts = np .unique (box_insts , return_counts = True )
211- real_labels = real_labels [1 :]
212- counts = counts [1 :]
213- max_pixels = np .max (counts )
214- max_label = real_labels [np .argmax (counts )]
215- for real_lab , count in list (zip (list (real_labels ), list (counts ))):
216- if count < max_pixels :
217- if count < size :
218- mask [mask == real_lab ] = max_label
219-
220- return mask
221-
222-
223191def post_proc_stardist (
224- dist_map : np .ndarray , stardist_map : np .ndarray , thresh : float = 0.4 , ** kwargs
192+ dist_map : np .ndarray ,
193+ stardist_map : np .ndarray ,
194+ score_thresh : float = 0.5 ,
195+ iou_thresh : float = 0.5 ,
196+ trim_bboxes : bool = True ,
197+ ** kwargs ,
225198) -> np .ndarray :
226- """Run post-processing for stardist.
199+ """Run post-processing for stardist outputs.
200+
201+ NOTE: This is not the original cpp version.
202+ This is a python re-implementation of the stardidst post-processing
203+ pipeline that uses non-maximum-suppression. Here, critical parts of the
204+ nms are accelerated with `numba` and `scipy.spatial.KDtree`.
227205
228- NOTE: This is not the original version that uses NMS.
229- This is rather a workaround that is a little slower.
206+ NOTE:
207+ This implementaiton of the stardist post-processing is actually nearly twice
208+ faster than the original version if `trim_bboxes` is set to True. The resulting
209+ segmentation is not an exact match but the differences are mostly neglible.
230210
231211 Parameters
232212 ----------
@@ -236,37 +216,75 @@ def post_proc_stardist(
236216 Predicted radial distances. Shape: (n_rays, H, W).
237217 thresh : float, default=0.4
238218 Threshold for the regressed distance transform.
219+ trim_bboxes : bool, default=True
220+ If True, The non-zero pixels are computed only from the cell contours
221+ which prunes down the pixel search space drastically.
239222
240223 Returns
241224 -------
242225 np.ndarray:
243226 Instance labelled mask. Shape: (H, W).
244227 """
245- stardist_map = stardist_map .transpose (1 , 2 , 0 )
246- mask = _ind_prob_thresh (dist_map , thresh , b = 2 )
247-
248- # invert distmap
249- inv_dist_map = 255 - img_as_ubyte (dist_map )
250-
251- # find markers from minima erosion reconstructed maxima of inv dist map
252- reconstructed = h_minima_reconstruction (inv_dist_map )
253- markers = find_maxima (reconstructed , mask = mask )
254- markers = ndi .label (markers )[0 ]
255- markers = remove_small_objects (markers , min_size = 5 )
256- points = np .array (
257- tuple (np .array (r .centroid ).astype (int ) for r in regionprops (markers ))
258- )
228+ if (
229+ not dist_map .ndim == 2
230+ and not stardist_map .ndim == 3
231+ and not dist_map .shape == stardist_map .shape [:2 ]
232+ ):
233+ raise ValueError (
234+ "Illegal input shapes. Make sure that: "
235+ f"`dist_map` has to have shape: (H, W). Got: { dist_map .shape } "
236+ f"`stardist_map` has to have shape (H, W, nrays). Got: { stardist_map .shape } "
237+ )
259238
260- if len ( points ) == 0 :
261- return np .zeros_like ( mask )
239+ dist = np . asarray ( stardist_map ). transpose ( 1 , 2 , 0 )
240+ prob = np .asarray ( dist_map )
262241
263- dist = stardist_map [ tuple ( points . T )]
264- scores = dist_map [ tuple ( points . T )]
242+ # threshold the edt distance transform map
243+ mask = _ind_prob_thresh ( prob , score_thresh )
265244
266- labels = polygons_to_label (
267- dist , points , prob = scores , shape = mask .shape , scale_dist = (1 , 1 )
245+ # get only the mask contours to trim down bbox search space
246+ if trim_bboxes :
247+ fp = disk (2 )
248+ mask -= erosion (mask , fp )
249+
250+ points = np .stack (np .where (mask ), axis = 1 )
251+
252+ # Get only non-zero pixels of the transforms
253+ dist = dist [mask > 0 ]
254+ scores = prob [mask > 0 ]
255+
256+ # sort descendingly
257+ ind = np .argsort (scores )[::- 1 ]
258+ dist = dist [ind ]
259+ scores = scores [ind ]
260+ points = points [ind ]
261+
262+ # get bounding boxes
263+ x1 , y1 , x2 , y2 , areas , max_dist = get_bboxes (dist , points )
264+ boxes = np .stack ([x1 , y1 , x2 , y2 ], axis = 1 )
265+
266+ # consider only boxes above score threshold
267+ score_cond = scores >= score_thresh
268+ boxes = boxes [score_cond ]
269+ scores = scores [score_cond ]
270+ areas = areas [score_cond ]
271+
272+ # run nms
273+ inds = nms_stardist (
274+ boxes ,
275+ points ,
276+ scores ,
277+ areas ,
278+ max_dist ,
279+ score_threshold = score_thresh ,
280+ iou_threshold = iou_thresh ,
268281 )
269- labels = _clean_up (labels , ** kwargs )
282+
283+ # get the centroids
284+ points = points [inds ]
285+ scores = scores [inds ]
286+ dist = dist [inds ]
287+ labels = polygons_to_label (dist , points , prob = scores , shape = dist_map .shape )
270288
271289 return labels
272290
0 commit comments