Skip to content

Commit 2668e52

Browse files
committed
Add Condition for Background Points
1 parent ebfc4be commit 2668e52

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

micro_sam/prompt_generators.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
4343
# ([x1, x2, ...], [y1, y2, ...])
4444
n_coordinates = len(object_coordinates[0])
4545

46-
if n_coordinates > n_positive_remaining: # for some cases, there aren't any forground object_coordinates
46+
if n_coordinates > n_positive_remaining: # for some cases, there aren't many forground object_coordinates
4747
# randomly sampling n_positive_remaining_points from these coordinates
4848
positive_indices = np.random.choice(n_coordinates, replace=False, size=n_positive_remaining)
4949
for positive_index in positive_indices:
@@ -53,7 +53,7 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
5353
coord_list.append(positive_coordinates)
5454
label_list.append(1)
5555
else:
56-
print(f"{n_coordinates} fg spotted..")
56+
print(f"{n_coordinates} foreground pixel spotted..")
5757

5858
# getting the negative points
5959
# for this we do the opposite and we set the mask to the bounding box - the object mask
@@ -74,14 +74,17 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
7474
# ([x1, x2, ...], [y1, y2, ...])
7575
n_coordinates = len(background_coordinates[0])
7676

77-
# randomly sample n_positive_remaining_points from these coordinates
78-
negative_indices = np.random.choice(n_coordinates, replace=False, size=n_negative_remaining)
79-
for negative_index in negative_indices:
80-
negative_coordinates = int(background_coordinates[0][negative_index]), \
81-
int(background_coordinates[1][negative_index])
77+
if n_coordinates > n_negative_remaining: # for some cases, there aren't many background object_coordinates
78+
# randomly sample n_positive_remaining_points from these coordinates
79+
negative_indices = np.random.choice(n_coordinates, replace=False, size=n_negative_remaining)
80+
for negative_index in negative_indices:
81+
negative_coordinates = int(background_coordinates[0][negative_index]), \
82+
int(background_coordinates[1][negative_index])
8283

83-
coord_list.append(negative_coordinates)
84-
label_list.append(0)
84+
coord_list.append(negative_coordinates)
85+
label_list.append(0)
86+
else:
87+
print(f"{n_coordinates} background pixel spotted..")
8588

8689
# returns object-level masks per instance for cross-verification (TODO: fix it later)
8790
if self.get_point_prompts is True and self.get_box_prompts is True: # we want points and box

0 commit comments

Comments
 (0)