
Commit 74e6f3a

Update type annotations
1 parent: 343361f · commit: 74e6f3a

2 files changed, +30 -19 lines

micro_sam/evaluation/inference.py

Lines changed: 1 addition & 1 deletion

@@ -458,7 +458,7 @@ def _run_inference_with_iterative_prompting_for_image(
     image_embeddings = model.image_embeddings_oft(input_images)

     multimasking = n_pos == 1
-    prompt_generator = IterativePromptGenerator(device)
+    prompt_generator = IterativePromptGenerator()

     n_samples = len(sampled_ids[0])
     n_batches = int(np.ceil(float(n_samples) / batch_size))
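The only change at this call site is that the prompt generator is no longer constructed with a device; after this commit the class infers the device from the tensors it is handed (see micro_sam/prompt_generators.py below). A minimal sketch of the updated construction, assuming micro_sam is importable:

from micro_sam.prompt_generators import IterativePromptGenerator

# The constructor no longer takes a device argument; device placement follows
# the tensors that are later passed to the generator's __call__
# (it reads object_mask.device there).
prompt_generator = IterativePromptGenerator()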

micro_sam/prompt_generators.py

Lines changed: 29 additions & 18 deletions

@@ -4,7 +4,7 @@
 """

 from collections.abc import Mapping
-from typing import Optional
+from typing import Optional, Tuple

 import numpy as np
 from scipy.ndimage import binary_dilation
@@ -195,10 +195,7 @@ def __call__(
 class IterativePromptGenerator:
     """Generate point prompts from an instance segmentation iteratively.
     """
-    def __init__(self, device=None):
-        self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
-
-    def get_positive_points(self, pos_region, overlap_region):
+    def _get_positive_points(self, pos_region, overlap_region):
         positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
         # we may have objects without a positive region (= missing true foreground)
         # in this case we just sample a point where the model was already correct
@@ -220,9 +217,10 @@ def get_positive_points(self, pos_region, overlap_region):
         return pos_coordinates, pos_labels

     # TODO get rid of this looped implementation and use proper batched computation instead
-    def get_negative_points(self, negative_region_batched, true_object_batched, gt_batched):
-        negative_coordinates, negative_labels = [], []
+    def _get_negative_points(self, negative_region_batched, true_object_batched, gt_batched):
+        device = negative_region_batched.device

+        negative_coordinates, negative_labels = [], []
         for neg_region, true_object, gt in zip(negative_region_batched, true_object_batched, gt_batched):

             tmp_neg_loc = torch.where(neg_region)
@@ -233,11 +231,11 @@ def get_negative_points(self, negative_region_batched, true_object_batched, gt_b
                                 torch.max(x_coords) + 1, torch.max(y_coords) + 1])
             bbox_mask = torch.zeros_like(true_object).squeeze(0)
             bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
-            bbox_mask = bbox_mask[None].to(self.device)
+            bbox_mask = bbox_mask[None].to(device)

             # NOTE: FIX: here we add dilation to the bbox because in some cases we couldn't find objects at all
             # TODO: just expand the pixels of bbox
-            dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
+            dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(device)).squeeze(0)
             background_mask = abs(dilated_bbox_mask - true_object)
             tmp_neg_loc = torch.where(background_mask)
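These two hunks replace the stored self.device with a device derived from the input batch, so temporary tensors such as the dilation kernel always land on the same device as the data. A small self-contained sketch of that pattern (the helper name is illustrative, not from the repository):

import torch

def _kernel_like(mask: torch.Tensor) -> torch.Tensor:
    # Derive the device from the input instead of storing it on the object,
    # mirroring the change in _get_negative_points above.
    device = mask.device
    return torch.ones(3, 3, device=device)

mask = torch.zeros(8, 8)             # a CPU tensor here; could equally be CUDA
kernel = _kernel_like(mask)          # the kernel follows the mask's device
assert kernel.device == mask.device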

@@ -258,25 +256,38 @@ def get_negative_points(self, negative_region_batched, true_object_batched, gt_b

     def __call__(
         self,
-        gt,
-        object_mask,
-        current_points,
-        current_labels
-    ):
+        gt: torch.Tensor,
+        object_mask: torch.Tensor,
+        current_points: torch.Tensor,
+        current_labels: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Generate the prompts for each object iteratively in the segmentation.
+
+        Args:
+            gt: The groundtruth segmentation.
+            object_mask: The predicted objects.
+            current_points: The current points.
+            current_labels: The current labels.
+
+        Returns:
+            The updated point prompt coordinates.
+            The updated point prompt labels.
         """
         assert gt.shape == object_mask.shape
-        true_object = gt.to(self.device)
+        device = object_mask.device
+
+        true_object = gt.to(device)
         expected_diff = (object_mask - true_object)
         neg_region = (expected_diff == 1).to(torch.float)
         pos_region = (expected_diff == -1)
         overlap_region = torch.logical_and(object_mask == 1, true_object == 1).to(torch.float32)

-        pos_coordinates, pos_labels = self.get_positive_points(pos_region, overlap_region)
-        neg_coordinates, neg_labels = self.get_negative_points(neg_region, true_object, gt)
+        pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region)
+        neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, gt)
         assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)

-        pos_coordinates, neg_coordinates = torch.tensor(pos_coordinates)[:, None], torch.tensor(neg_coordinates)[:, None]
+        pos_coordinates = torch.tensor(pos_coordinates)[:, None]
+        neg_coordinates = torch.tensor(neg_coordinates)[:, None]
         pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]

         net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
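With the new annotations, __call__ takes the groundtruth mask, the predicted mask, and the current point prompts, and returns the updated coordinate and label tensors. The toy snippet below only mirrors the concatenation step at the end of the method, appending one new positive and one new negative point per object; the (objects, n_points, 2) layout and the matching label concatenation are assumptions for illustration, not shown in this diff:

import torch

current_points = torch.zeros(2, 3, 2)                              # assumed layout: (objects, n_points, 2)
current_labels = torch.ones(2, 3)                                  # assumed layout: (objects, n_points)

pos_coordinates = torch.tensor([[10., 12.], [40., 7.]])[:, None]   # one new positive point per object
neg_coordinates = torch.tensor([[3., 3.], [55., 60.]])[:, None]    # one new negative point per object
pos_labels = torch.ones(2)[:, None]                                # positive points (label 1 in the SAM convention)
neg_labels = torch.zeros(2)[:, None]                               # negative points (label 0 in the SAM convention)

net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
print(net_coords.shape, net_labels.shape)  # torch.Size([2, 5, 2]) torch.Size([2, 5])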
