Skip to content

Commit 5a06e0c

Browse files
authored
Make ZSL VPT masks refinement configurable (#203)
* Make ZSL VPT masks refinement configurable
* Move extra decoding flag to infer
1 parent 449ced0 commit 5a06e0c

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

model_api/python/model_api/models/visual_prompting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,16 @@ def __call__(
313313
self,
314314
image: np.ndarray,
315315
reference_features: VisualPromptingFeatures | None = None,
316+
apply_masks_refinement: bool = True,
316317
) -> ZSLVisualPromptingResult:
317318
"""A wrapper of the SAMLearnableVisualPrompter.infer() method"""
318-
return self.infer(image, reference_features)
319+
return self.infer(image, reference_features, apply_masks_refinement)
319320

320321
def infer(
321322
self,
322323
image: np.ndarray,
323324
reference_features: VisualPromptingFeatures | None = None,
325+
apply_masks_refinement: bool = True,
324326
) -> ZSLVisualPromptingResult:
325327
"""
326328
Obtains masks by already prepared reference features.
@@ -332,6 +334,8 @@ def infer(
332334
image (np.ndarray): HWC-shaped image
333335
reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during previous learn() calls.
334336
If not passed, object internal state is used, which reflects the last learn() call. Defaults to None.
337+
apply_masks_refinement (bool, optional): Flag controlling an additional refinement stage during inference. When enabled, the decoder will
338+
be run two more times to refine the masks obtained from the first decoder call. Defaults to True.
335339
336340
Returns:
337341
ZSLVisualPromptingResult: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of
@@ -401,7 +405,9 @@ def infer(
401405
}
402406
inputs_decoder["image_embeddings"] = image_embeddings
403407

404-
prediction = self._predict_masks(inputs_decoder, original_shape, True)
408+
prediction = self._predict_masks(
409+
inputs_decoder, original_shape, apply_masks_refinement
410+
)
405411
prediction.update({"scores": points_score[-1]})
406412

407413
predicted_masks[label].append(prediction[self.decoder.output_blob_name])

0 commit comments

Comments
 (0)