Skip to content

Commit 980f942

Browse files
Merge pull request #120 from computational-cell-analytics/multi-output
Implement use_best_multimask for segment_from_points
2 parents b9d9beb + 4aa8fba commit 980f942

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

micro_sam/prompt_based_segmentation.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def segment_from_points(
251251
i: Optional[int] = None,
252252
multimask_output: bool = False,
253253
return_all: bool = False,
254+
use_best_multimask: Optional[bool] = None,
254255
):
255256
"""Segmentation from point prompts.
256257
@@ -264,6 +265,8 @@ def segment_from_points(
264265
or a time dimension and two spatial dimensions.
265266
multimask_output: Whether to return multiple or just a single mask.
266267
return_all: Whether to return the score and logits in addition to the mask.
268+
use_best_multimask: Whether to use multimask output and then choose the best mask.
269+
By default this is used for a single positive point and not otherwise.
267270
268271
Returns:
269272
The binary segmentation mask.
@@ -273,13 +276,21 @@ def segment_from_points(
273276
)
274277
points, labels = prompts
275278

279+
if use_best_multimask is None:
280+
use_best_multimask = len(points) == 1 and labels[0] == 1
281+
multimask_output_ = multimask_output or use_best_multimask
282+
276283
# predict the mask
277284
mask, scores, logits = predictor.predict(
278285
point_coords=points[:, ::-1], # SAM has reversed XY conventions
279286
point_labels=labels,
280-
multimask_output=multimask_output,
287+
multimask_output=multimask_output_,
281288
)
282289

290+
if use_best_multimask:
291+
best_mask_id = np.argmax(scores)
292+
mask = mask[best_mask_id][None]
293+
283294
if tile is not None:
284295
return _tile_to_full_mask(mask, shape, tile, return_all, multimask_output)
285296
if return_all:

test/test_prompt_based_segmentation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,20 @@ def setUpClass(cls):
3333
def test_segment_from_points(self):
3434
from micro_sam.prompt_based_segmentation import segment_from_points
3535

36+
# segment with one positive and four negative points
3637
points = np.array([[128, 128], [64, 64], [192, 192], [64, 192], [192, 64]])
3738
labels = np.array([1, 0, 0, 0, 0])
3839

3940
predicted = segment_from_points(self.predictor, points, labels)
4041
self.assertGreater(util.compute_iou(self.mask, predicted), 0.9)
4142

43+
# segment with one positive point, using the best multimask
44+
points = np.array([[128, 128]])
45+
labels = np.array([1])
46+
47+
predicted = segment_from_points(self.predictor, points, labels)
48+
self.assertGreater(util.compute_iou(self.mask, predicted), 0.9)
49+
4250
def _test_segment_from_mask(self, shape=(256, 256)):
4351
from micro_sam.prompt_based_segmentation import segment_from_mask
4452

0 commit comments

Comments
 (0)