@@ -313,14 +313,16 @@ def __call__(
313313 self ,
314314 image : np .ndarray ,
315315 reference_features : VisualPromptingFeatures | None = None ,
316+ apply_masks_refinement : bool = True ,
316317 ) -> ZSLVisualPromptingResult :
317318 """A wrapper of the SAMLearnableVisualPrompter.infer() method"""
318- return self .infer (image , reference_features )
319+ return self .infer (image , reference_features , apply_masks_refinement )
319320
320321 def infer (
321322 self ,
322323 image : np .ndarray ,
323324 reference_features : VisualPromptingFeatures | None = None ,
325+ apply_masks_refinement : bool = True ,
324326 ) -> ZSLVisualPromptingResult :
325327 """
326328 Obtains masks by already prepared reference features.
@@ -332,6 +334,8 @@ def infer(
332334 image (np.ndarray): HWC-shaped image
333335 reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during previous learn() calls.
334336 If not passed, object internal state is used, which reflects the last learn() call. Defaults to None.
337+ apply_masks_refinement (bool, optional): Flag controlling an additional refinement stage during inference. When enabled, the decoder
338+ is launched 2 extra times to refine the masks obtained from the first decoder call. Defaults to True.
335339
336340 Returns:
337341 ZSLVisualPromptingResult: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of
@@ -401,7 +405,9 @@ def infer(
401405 }
402406 inputs_decoder ["image_embeddings" ] = image_embeddings
403407
404- prediction = self ._predict_masks (inputs_decoder , original_shape , True )
408+ prediction = self ._predict_masks (
409+ inputs_decoder , original_shape , apply_masks_refinement
410+ )
405411 prediction .update ({"scores" : points_score [- 1 ]})
406412
407413 predicted_masks [label ].append (prediction [self .decoder .output_blob_name ])
0 commit comments