@@ -199,46 +199,62 @@ def __init__(self, device=None):
199199 self .device = device if device is not None else "cuda" if torch .cuda .is_available () else "cpu"
200200
201201 def get_positive_points (self , pos_region , overlap_region ):
202- tmp_pos_loc = torch .where (pos_region )
203- # condiion below where there is no room for improvement for the model
204- # hence we put a positive point in the "already correct" regions
205- if torch .stack (tmp_pos_loc ).shape [- 1 ] == 0 :
206- tmp_pos_loc = torch .where (overlap_region )
207-
208- pos_index = np .random .choice (len (tmp_pos_loc [1 ]))
209- pos_coordinates = int (tmp_pos_loc [1 ][pos_index ]), int (tmp_pos_loc [2 ][pos_index ])
210- pos_coordinates = pos_coordinates [::- 1 ]
211- pos_labels = 1
202+ positive_locations = [torch .where (pos_reg ) for pos_reg in pos_region ]
203+ # we may have objects withput a positive region (= missing true foreground)
204+ # in this case we just sample a point where the model was already correct
205+ positive_locations = [
206+ torch .where (ovlp_reg ) if len (pos_loc [0 ]) == 0 else pos_loc
207+ for pos_loc , ovlp_reg in zip (positive_locations , overlap_region )
208+ ]
209+ # we sample one location for each object in the batch
210+ sampled_indices = [np .random .choice (len (pos_loc [0 ])) for pos_loc in positive_locations ]
211+ # get the corresponding coordinates (Note that we flip the axis order here due to the expected order of SAM)
212+ pos_coordinates = [
213+ [pos_loc [- 1 ][idx ], pos_loc [- 2 ][idx ]] for pos_loc , idx in zip (positive_locations , sampled_indices )
214+ ]
215+
216+ # make sure that we still have the correct batch size
217+ assert len (pos_coordinates ) == pos_region .shape [0 ]
218+ pos_labels = [1 ] * len (pos_coordinates )
219+
212220 return pos_coordinates , pos_labels
213221
214- def get_negative_points (self , neg_region , true_object , gt ):
215- tmp_neg_loc = torch .where (neg_region )
216- if torch .stack (tmp_neg_loc ).shape [- 1 ] == 0 :
217- tmp_true_loc = torch .where (true_object )
218- x_coords , y_coords = tmp_true_loc [1 ], tmp_true_loc [2 ]
219- bbox = torch .stack ([torch .min (x_coords ), torch .min (y_coords ),
220- torch .max (x_coords ) + 1 , torch .max (y_coords ) + 1 ])
221- bbox_mask = torch .zeros_like (true_object ).squeeze (0 )
222- bbox_mask [bbox [0 ]:bbox [2 ], bbox [1 ]:bbox [3 ]] = 1
223- bbox_mask = bbox_mask [None ].to (self .device )
224-
225- # NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
226- # TODO: just expand the pixels of bbox
227- dilated_bbox_mask = dilation (bbox_mask [None ], torch .ones (3 , 3 ).to (self .device )).squeeze (0 )
228- background_mask = abs (dilated_bbox_mask - true_object )
229- tmp_neg_loc = torch .where (background_mask )
230-
231- # there is a chance that the object is small to not return a decent-sized bounding box
232- # hence we might not find points sometimes there as well, hence we sample points from true background
233- if torch .stack (tmp_neg_loc ).shape [- 1 ] == 0 :
234- tmp_neg_loc = torch .where (gt == 0 )
222+ # TODO get rid of this looped implementation and use proper batched computation instead
223+ def get_negative_points (self , negative_region_batched , true_object_batched , gt_batched ):
224+ negative_coordinates , negative_labels = [], []
235225
236- neg_index = np .random .choice (len (tmp_neg_loc [1 ]))
237- neg_coordinates = int (tmp_neg_loc [1 ][neg_index ]), int (tmp_neg_loc [2 ][neg_index ])
238- neg_coordinates = neg_coordinates [::- 1 ]
239- neg_labels = 0
226+ for neg_region , true_object , gt in zip (negative_region_batched , true_object_batched , gt_batched ):
240227
241- return neg_coordinates , neg_labels
228+ tmp_neg_loc = torch .where (neg_region )
229+ if torch .stack (tmp_neg_loc ).shape [- 1 ] == 0 :
230+ tmp_true_loc = torch .where (true_object )
231+ x_coords , y_coords = tmp_true_loc [1 ], tmp_true_loc [2 ]
232+ bbox = torch .stack ([torch .min (x_coords ), torch .min (y_coords ),
233+ torch .max (x_coords ) + 1 , torch .max (y_coords ) + 1 ])
234+ bbox_mask = torch .zeros_like (true_object ).squeeze (0 )
235+ bbox_mask [bbox [0 ]:bbox [2 ], bbox [1 ]:bbox [3 ]] = 1
236+ bbox_mask = bbox_mask [None ].to (self .device )
237+
238+ # NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
239+ # TODO: just expand the pixels of bbox
240+ dilated_bbox_mask = dilation (bbox_mask [None ], torch .ones (3 , 3 ).to (self .device )).squeeze (0 )
241+ background_mask = abs (dilated_bbox_mask - true_object )
242+ tmp_neg_loc = torch .where (background_mask )
243+
244+ # there is a chance that the object is small to not return a decent-sized bounding box
245+ # hence we might not find points sometimes there as well, hence we sample points from true background
246+ if torch .stack (tmp_neg_loc ).shape [- 1 ] == 0 :
247+ tmp_neg_loc = torch .where (gt == 0 )
248+
249+ neg_index = np .random .choice (len (tmp_neg_loc [1 ]))
250+ neg_coordinates = [tmp_neg_loc [1 ][neg_index ], tmp_neg_loc [2 ][neg_index ]]
251+ neg_coordinates = neg_coordinates [::- 1 ]
252+ neg_labels = 0
253+
254+ negative_coordinates .append (neg_coordinates )
255+ negative_labels .append (neg_labels )
256+
257+ return negative_coordinates , negative_labels
242258
243259 def __call__ (
244260 self ,
@@ -249,6 +265,7 @@ def __call__(
249265 ):
250266 """Generate the prompts for each object iteratively in the segmentation.
251267 """
268+ assert gt .shape == object_mask .shape
252269 true_object = gt .to (self .device )
253270 expected_diff = (object_mask - true_object )
254271 neg_region = (expected_diff == 1 ).to (torch .float )
@@ -257,8 +274,12 @@ def __call__(
257274
258275 pos_coordinates , pos_labels = self .get_positive_points (pos_region , overlap_region )
259276 neg_coordinates , neg_labels = self .get_negative_points (neg_region , true_object , gt )
277+ assert len (pos_coordinates ) == len (pos_labels ) == len (neg_coordinates ) == len (neg_labels )
278+
279+ pos_coordinates , neg_coordinates = torch .tensor (pos_coordinates )[:, None ], torch .tensor (neg_coordinates )[:, None ]
280+ pos_labels , neg_labels = torch .tensor (pos_labels )[:, None ], torch .tensor (neg_labels )[:, None ]
260281
261- net_coords = torch .cat ([current_points , torch . tensor ([[ pos_coordinates , neg_coordinates ]]) ], dim = 1 )
262- net_labels = torch .cat ([current_labels , torch . tensor ([[ pos_labels , neg_labels ]]) ], dim = 1 )
282+ net_coords = torch .cat ([current_points , pos_coordinates , neg_coordinates ], dim = 1 )
283+ net_labels = torch .cat ([current_labels , pos_labels , neg_labels ], dim = 1 )
263284
264285 return net_coords , net_labels
0 commit comments