44"""
55
66from collections .abc import Mapping
7- from typing import Optional
7+ from typing import Optional , Tuple
88
99import numpy as np
1010from scipy .ndimage import binary_dilation
@@ -195,10 +195,7 @@ def __call__(
195195class IterativePromptGenerator :
196196 """Generate point prompts from an instance segmentation iteratively.
197197 """
198- def __init__ (self , device = None ):
199- self .device = device if device is not None else "cuda" if torch .cuda .is_available () else "cpu"
200-
201- def get_positive_points (self , pos_region , overlap_region ):
198+ def _get_positive_points (self , pos_region , overlap_region ):
202199 positive_locations = [torch .where (pos_reg ) for pos_reg in pos_region ]
203200 # we may have objects without a positive region (= missing true foreground)
204201 # in this case we just sample a point where the model was already correct
@@ -220,9 +217,10 @@ def get_positive_points(self, pos_region, overlap_region):
220217 return pos_coordinates , pos_labels
221218
222219 # TODO get rid of this looped implementation and use proper batched computation instead
223- def get_negative_points (self , negative_region_batched , true_object_batched , gt_batched ):
224- negative_coordinates , negative_labels = [], []
220+ def _get_negative_points (self , negative_region_batched , true_object_batched , gt_batched ):
221+ device = negative_region_batched . device
225222
223+ negative_coordinates , negative_labels = [], []
226224 for neg_region , true_object , gt in zip (negative_region_batched , true_object_batched , gt_batched ):
227225
228226 tmp_neg_loc = torch .where (neg_region )
@@ -233,11 +231,11 @@ def get_negative_points(self, negative_region_batched, true_object_batched, gt_b
233231 torch .max (x_coords ) + 1 , torch .max (y_coords ) + 1 ])
234232 bbox_mask = torch .zeros_like (true_object ).squeeze (0 )
235233 bbox_mask [bbox [0 ]:bbox [2 ], bbox [1 ]:bbox [3 ]] = 1
236- bbox_mask = bbox_mask [None ].to (self . device )
234+ bbox_mask = bbox_mask [None ].to (device )
237235
238236 # NOTE: FIX: here we add dilation to the bbox because in some cases we couldn't find objects at all
239237 # TODO: just expand the pixels of bbox
240- dilated_bbox_mask = dilation (bbox_mask [None ], torch .ones (3 , 3 ).to (self . device )).squeeze (0 )
238+ dilated_bbox_mask = dilation (bbox_mask [None ], torch .ones (3 , 3 ).to (device )).squeeze (0 )
241239 background_mask = abs (dilated_bbox_mask - true_object )
242240 tmp_neg_loc = torch .where (background_mask )
243241
@@ -258,25 +256,38 @@ def get_negative_points(self, negative_region_batched, true_object_batched, gt_b
258256
259257 def __call__ (
260258 self ,
261- gt ,
262- object_mask ,
263- current_points ,
264- current_labels
265- ):
259+ gt : torch . Tensor ,
260+ object_mask : torch . Tensor ,
261+ current_points : torch . Tensor ,
262+ current_labels : torch . Tensor
263+ ) -> Tuple [ torch . Tensor , torch . Tensor ] :
266264 """Generate the prompts for each object iteratively in the segmentation.
265+
266+ Args:
267+ gt: The groundtruth segmentation.
268+ object_mask: The predicted objects.
269+ current_points: The current points.
270+ current_labels: The current labels.
271+
272+ Returns:
273+ The updated point prompt coordinates.
274+ The updated point prompt labels.
267275 """
268276 assert gt .shape == object_mask .shape
269- true_object = gt .to (self .device )
277+ device = object_mask .device
278+
279+ true_object = gt .to (device )
270280 expected_diff = (object_mask - true_object )
271281 neg_region = (expected_diff == 1 ).to (torch .float )
272282 pos_region = (expected_diff == - 1 )
273283 overlap_region = torch .logical_and (object_mask == 1 , true_object == 1 ).to (torch .float32 )
274284
275- pos_coordinates , pos_labels = self .get_positive_points (pos_region , overlap_region )
276- neg_coordinates , neg_labels = self .get_negative_points (neg_region , true_object , gt )
285+ pos_coordinates , pos_labels = self ._get_positive_points (pos_region , overlap_region )
286+ neg_coordinates , neg_labels = self ._get_negative_points (neg_region , true_object , gt )
277287 assert len (pos_coordinates ) == len (pos_labels ) == len (neg_coordinates ) == len (neg_labels )
278288
279- pos_coordinates , neg_coordinates = torch .tensor (pos_coordinates )[:, None ], torch .tensor (neg_coordinates )[:, None ]
289+ pos_coordinates = torch .tensor (pos_coordinates )[:, None ]
290+ neg_coordinates = torch .tensor (neg_coordinates )[:, None ]
280291 pos_labels , neg_labels = torch .tensor (pos_labels )[:, None ], torch .tensor (neg_labels )[:, None ]
281292
282293 net_coords = torch .cat ([current_points , pos_coordinates , neg_coordinates ], dim = 1 )
0 commit comments