Skip to content

Commit 56844fb

Browse files
committed
Generalise prompting condition for all objects
1 parent 2668e52 commit 56844fb

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

micro_sam/prompt_generators.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,16 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
4343
# ([x1, x2, ...], [y1, y2, ...])
4444
n_coordinates = len(object_coordinates[0])
4545

46-
if n_coordinates > n_positive_remaining: # for some cases, there aren't many forground object_coordinates
47-
# randomly sampling n_positive_remaining_points from these coordinates
48-
positive_indices = np.random.choice(n_coordinates, replace=False, size=n_positive_remaining)
49-
for positive_index in positive_indices:
50-
positive_coordinates = int(object_coordinates[0][positive_index]), \
51-
int(object_coordinates[1][positive_index])
52-
53-
coord_list.append(positive_coordinates)
54-
label_list.append(1)
55-
else:
56-
print(f"{n_coordinates} foreground pixel spotted..")
46+
# randomly sampling n_positive_remaining_points from these coordinates
47+
positive_indices = np.random.choice(n_coordinates, replace=False,
48+
size=min(n_positive_remaining, n_coordinates) # handles the cases with insufficient fg pixels
49+
)
50+
for positive_index in positive_indices:
51+
positive_coordinates = int(object_coordinates[0][positive_index]), \
52+
int(object_coordinates[1][positive_index])
53+
54+
coord_list.append(positive_coordinates)
55+
label_list.append(1)
5756

5857
# getting the negative points
5958
# for this we do the opposite and we set the mask to the bounding box - the object mask
@@ -74,17 +73,16 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7473
# ([x1, x2, ...], [y1, y2, ...])
7574
n_coordinates = len(background_coordinates[0])
7675

77-
if n_coordinates > n_negative_remaining: # for some cases, there aren't many background object_coordinates
78-
# randomly sample n_positive_remaining_points from these coordinates
79-
negative_indices = np.random.choice(n_coordinates, replace=False, size=n_negative_remaining)
80-
for negative_index in negative_indices:
81-
negative_coordinates = int(background_coordinates[0][negative_index]), \
82-
int(background_coordinates[1][negative_index])
83-
84-
coord_list.append(negative_coordinates)
85-
label_list.append(0)
86-
else:
87-
print(f"{n_coordinates} background pixel spotted..")
76+
# randomly sample n_positive_remaining_points from these coordinates
77+
negative_indices = np.random.choice(n_coordinates, replace=False,
78+
size=min(n_negative_remaining, n_coordinates) # handles the cases with insufficient bg pixels
79+
)
80+
for negative_index in negative_indices:
81+
negative_coordinates = int(background_coordinates[0][negative_index]), \
82+
int(background_coordinates[1][negative_index])
83+
84+
coord_list.append(negative_coordinates)
85+
label_list.append(0)
8886

8987
# returns object-level masks per instance for cross-verification (TODO: fix it later)
9088
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box

0 commit comments

Comments
 (0)