Skip to content

Commit cd4a1ea

Browse files
Merge pull request #61 from computational-cell-analytics/update-prompter
Update Prompt Generator - First Point as Any FG Point
2 parents 3087d03 + 8d52cdc commit cd4a1ea

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

micro_sam/prompt_generators.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, n_positive_points, n_negative_points, dilation_strength,
1414
if self.get_point_prompts is False and self.get_box_prompts is False:
1515
raise ValueError("You need to request for box/point prompts or both")
1616

17-
def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
17+
def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
1818
"""
1919
Parameters:
2020
gt: True Labels
@@ -25,17 +25,23 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
2525
coord_list = []
2626
label_list = []
2727

28-
# getting the center coordinate as the first positive point
29-
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
30-
label_list.append(1)
28+
if center_coordinates is not None:
29+
# getting the center coordinate as the first positive point (OPTIONAL)
30+
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
31+
label_list.append(1)
32+
33+
# getting the additional positive points by randomly sampling points from this mask except the center coordinate
34+
n_positive_remaining = self.n_positive_points - 1
35+
36+
else:
37+
# need to sample "self.n_positive_points" number of points
38+
n_positive_remaining = self.n_positive_points
3139

3240
if self.get_box_prompts:
3341
bbox_list = [bbox_coordinates]
3442

3543
object_mask = gt == gt_id
3644

37-
# getting the additional positive points by randomly sampling points from this mask
38-
n_positive_remaining = self.n_positive_points - 1
3945
if n_positive_remaining > 0:
4046
# all coordinates of our current object
4147
object_coordinates = np.where(object_mask)

test/test_prompt_generators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_point_prompt_generator(self):
5353
generator = PointAndBoxPromptGenerator(n_pos, n_neg, dilation_strength=4)
5454
for label_id in label_ids:
5555
center, box = centers.get(label_id), boxes.get(label_id)
56-
coords, point_labels, _, _ = generator(labels, label_id, center, box)
56+
coords, point_labels, _, _ = generator(labels, label_id, box, center)
5757
coords_ = (np.array([int(coo[0]) for coo in coords]),
5858
np.array([int(coo[1]) for coo in coords]))
5959
mask = labels == label_id
@@ -80,7 +80,7 @@ def test_box_prompt_generator(self):
8080

8181
for label_id in label_ids:
8282
center, box_ = centers.get(label_id), boxes.get(label_id)
83-
_, _, box, _ = generator(labels, label_id, center, box_)
83+
_, _, box, _ = generator(labels, label_id, box_, center)
8484
coords = np.where(labels == label_id)
8585
expected_box = [coo.min() for coo in coords] + [coo.max() + 1 for coo in coords]
8686
self.assertEqual(expected_box, list(box[0]))

0 commit comments

Comments
 (0)