@@ -251,6 +251,7 @@ def segment_from_points(
251251 i : Optional [int ] = None ,
252252 multimask_output : bool = False ,
253253 return_all : bool = False ,
254+ use_best_multimask : Optional [bool ] = None ,
254255):
255256 """Segmentation from point prompts.
256257
@@ -264,6 +265,8 @@ def segment_from_points(
264265 or a time dimension and two spatial dimensions.
265266 multimask_output: Whether to return multiple or just a single mask.
266267 return_all: Whether to return the score and logits in addition to the mask.
268+ use_best_multimask: Whether to use multimask output and then choose the best mask.
269+ By default this is used for a single positive point and not otherwise.
267270
268271 Returns:
269272 The binary segmentation mask.
@@ -273,13 +276,21 @@ def segment_from_points(
273276 )
274277 points , labels = prompts
275278
279+ if use_best_multimask is None :
280+ use_best_multimask = len (points ) == 1 and labels [0 ] == 1
281+ multimask_output_ = multimask_output or use_best_multimask
282+
276283 # predict the mask
277284 mask , scores , logits = predictor .predict (
278285 point_coords = points [:, ::- 1 ], # SAM has reversed XY conventions
279286 point_labels = labels ,
280- multimask_output = multimask_output ,
287+ multimask_output = multimask_output_ ,
281288 )
282289
290+ if use_best_multimask :
291+ best_mask_id = np .argmax (scores )
292+ mask = mask [best_mask_id ][None ]
293+
283294 if tile is not None :
284295 return _tile_to_full_mask (mask , shape , tile , return_all , multimask_output )
285296 if return_all :
0 commit comments