1+ from collections .abc import Mapping
2+ from typing import Optional
3+
14import numpy as np
25from scipy .ndimage import binary_dilation
36
47
58class PointAndBoxPromptGenerator :
6- def __init__ (self , n_positive_points , n_negative_points , dilation_strength ,
7- get_point_prompts = True , get_box_prompts = False ):
9+ """Generate point and/or box prompts from an instance segmentation.
10+
11+ You can use this class to derive prompts from an instance segmentation, either for
12+ evaluation purposes or for training Segment Anything on custom data.
13+ In order to use this generator you need to precompute the bounding boxes and center
14+ coordinates of the instance segmentation, using e.g. `util.get_bounding_boxes_and_centers`.
15+ Here's an example for how to use this class:
16+ ```python
17+ # Initialize generator for 1 positive and 4 negative point prompts.
18+ prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)
19+ # Precompute the bounding boxes for the given segmentation
20+ bounding_boxes, _ = util.get_bounding_boxes_and_centers(segmentation)
21+ # generate point prompts for the object with id 1 in 'segmentation'
22+ seg_id = 1
23+ points, point_labels, _, _ = prompt_generator(segmentation, seg_id, bounding_boxes)
24+ ```
25+
26+ Args:
27+ n_positive_points: The number of positive point prompts to generate per mask.
28+ n_negative_points: The number of negative point prompts to generate per mask.
29+ dilation_strength: The number of binary dilation iterations applied to the mask before sampling negative point prompts.
30+ get_point_prompts: Whether to generate point prompts.
31+ get_box_prompts: Whether to generate box prompts.
32+ """
33+ def __init__ (
34+ self ,
35+ n_positive_points : int ,
36+ n_negative_points : int ,
37+ dilation_strength : int ,
38+ get_point_prompts : bool = True ,
39+ get_box_prompts : bool = False
40+ ):
841 self .n_positive_points = n_positive_points
942 self .n_negative_points = n_negative_points
1043 self .dilation_strength = dilation_strength
1144 self .get_box_prompts = get_box_prompts
1245 self .get_point_prompts = get_point_prompts
1346
1447 if self .get_point_prompts is False and self .get_box_prompts is False :
15- raise ValueError ("You need to request for box/point prompts or both" )
16-
17- def __call__ (self , gt , gt_id , bbox_coordinates , center_coordinates = None ):
18- """
19- Parameters:
20- gt: True Labels
21- gt_id: Instance ID for the Cells
22- center_coordinates: Coordinates for the centroid seed of the cell
23- bbox_coordinates: Bounding box coordinates around the cell
48+ raise ValueError ("You need to request box prompts, point prompts or both." )
49+
50+ def __call__ (
51+ self ,
52+ segmentation : np .ndarray ,
53+ segmentation_id : int ,
54+ bbox_coordinates : Mapping [int , tuple ],
55+ center_coordinates : Optional [Mapping [int , np .ndarray ]] = None
56+ ) -> tuple [
57+ Optional [list [tuple ]], Optional [list [int ]], Optional [list [tuple ]], np .ndarray
58+ ]:
59+ """Generate the prompts for one object in the segmentation.
60+
61+ Args:
62+ segmentation: The instance segmentation.
63+ segmentation_id: The ID of the instance.
64+ bbox_coordinates: The precomputed bounding boxes of all objects in the segmentation.
65+ center_coordinates: The precomputed center coordinates of all objects in the segmentation.
66+ If passed, these coordinates will be used as the first positive point prompt.
67+ If not passed a random point from within the object mask will be used.
68+
69+ Returns:
70+ List of point coordinates. Returns None, if get_point_prompts is false.
71+ List of point labels. Returns None, if get_point_prompts is false.
72+ List containing the object bounding box. Returns None, if get_box_prompts is false.
73+ Object mask.
2474 """
2575 coord_list = []
2676 label_list = []
@@ -30,7 +80,8 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
3080 coord_list .append (tuple (map (int , center_coordinates ))) # to get int coords instead of float
3181 label_list .append (1 )
3282
33- # getting the additional positive points by randomly sampling points from this mask except the center coordinate
83+ # getting the additional positive points by randomly sampling points
84+ # from this mask except the center coordinate
3485 n_positive_remaining = self .n_positive_points - 1
3586
3687 else :
@@ -40,7 +91,7 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
4091 if self .get_box_prompts :
4192 bbox_list = [bbox_coordinates ]
4293
43- object_mask = gt == gt_id
94+ object_mask = segmentation == segmentation_id
4495
4596 if n_positive_remaining > 0 :
4697 # all coordinates of our current object
@@ -50,9 +101,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
50101 n_coordinates = len (object_coordinates [0 ])
51102
52103 # randomly sampling n_positive_remaining_points from these coordinates
53- positive_indices = np .random .choice (n_coordinates , replace = False ,
54- size = min (n_positive_remaining , n_coordinates ) # handles the cases with insufficient fg pixels
55- )
104+ positive_indices = np .random .choice (
105+ n_coordinates , replace = False ,
106+ size = min (n_positive_remaining , n_coordinates ) # handles the cases with insufficient fg pixels
107+ )
56108 for positive_index in positive_indices :
57109 positive_coordinates = int (object_coordinates [0 ][positive_index ]), \
58110 int (object_coordinates [1 ][positive_index ])
@@ -63,10 +115,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
63115 # getting the negative points
64116 # for this we do the opposite and we set the mask to the bounding box - the object mask
65117 # we need to dilate the object mask before doing this: we use scipy.ndimage.binary_dilation for this
66- dilated_object = binary_dilation (object_mask , iterations = self .dilation_strength )
67- background_mask = np .zeros (gt .shape )
118+ dilated_object = binary_dilation (object_mask , iterations = self .dilation_strength )
119+ background_mask = np .zeros (segmentation .shape )
68120 background_mask [bbox_coordinates [0 ]:bbox_coordinates [2 ], bbox_coordinates [1 ]:bbox_coordinates [3 ]] = 1
69- background_mask = binary_dilation (background_mask , iterations = self .dilation_strength )
121+ background_mask = binary_dilation (background_mask , iterations = self .dilation_strength )
70122 background_mask = abs (
71123 background_mask .astype (np .float32 ) - dilated_object .astype (np .float32 )
72124 ) # casting booleans to do subtraction
@@ -80,9 +132,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
80132 n_coordinates = len (background_coordinates [0 ])
81133
82134 # randomly sample n_negative_remaining points from these coordinates
83- negative_indices = np .random .choice (n_coordinates , replace = False ,
84- size = min (n_negative_remaining , n_coordinates ) # handles the cases with insufficient bg pixels
85- )
135+ negative_indices = np .random .choice (
136+ n_coordinates , replace = False ,
137+ size = min (n_negative_remaining , n_coordinates ) # handles the cases with insufficient bg pixels
138+ )
86139 for negative_index in negative_indices :
87140 negative_coordinates = int (background_coordinates [0 ][negative_index ]), \
88141 int (background_coordinates [1 ][negative_index ])
0 commit comments