@@ -143,7 +143,8 @@ def __init__(
143143 encoder_model (SAMImageEncoder): initialized decoder wrapper
144144 decoder_model (SAMDecoder): initialized encoder wrapper
145145 reference_features (VisualPromptingFeatures | None, optional): Previously generated reference features.
146- Once the features are passed, one can skip learn() method, and start predicting masks right away. Defaults to None.
146+ Once the features are passed, one can skip learn() method, and start predicting masks right away.
147+ Defaults to None.
147148 threshold (float, optional): Threshold to match vs reference features on infer(). Greater value means a
148149 stricter matching. Defaults to 0.65.
149150 """
@@ -213,7 +214,8 @@ def learn(
213214 reset_features (bool, optional): Forces learning from scratch. Defaults to False.
214215
215216 Returns:
216- tuple[VisualPromptingFeatures, np.ndarray]: return values are the updated VPT reference features and reference masks.
217+ tuple[VisualPromptingFeatures, np.ndarray]: return values are the updated VPT reference features and
218+ reference masks.
217219 The shape of the reference mask is N_labels x H x W, where H and W are the same as in the input image.
218220 """
219221 if boxes is None and points is None and polygons is None :
@@ -317,26 +319,32 @@ def infer(
317319
318320 Args:
319321 image (np.ndarray): HWC-shaped image
320- reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during previous learn() calls.
321- If not passed, object internal state is used, which reflects the last learn() call. Defaults to None.
322- apply_masks_refinement (bool, optional): Flag controlling additional refinement stage on inference. Once enabled, decoder will
323- be launched 2 extra times to refine the masks obtained with the first decoder call. Defaults to True.
322+ reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during
323+ previous learn() calls. If not passed, object internal state is used, which reflects the last learn()
324+ call. Defaults to None.
325+ apply_masks_refinement (bool, optional): Flag controlling additional refinement stage on inference.
326+ Once enabled, decoder will be launched 2 extra times to refine the masks obtained with the first decoder
327+ call. Defaults to True.
324328
325329 Returns:
326- ZSLVisualPromptingResult: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of
327- related prompts. Each binary mask corresponds to one prompt point. Class mask can be obtained by applying OR operation to all
328- mask corresponding to one label.
330+ ZSLVisualPromptingResult: Mapping label -> predicted mask. Each mask object contains a list of binary masks,
331+ and a list of related prompts. Each binary mask corresponds to one prompt point. Class mask can be
332+ obtained by applying OR operation to all mask corresponding to one label.
329333 """
330334 if reference_features is None :
331335 if self ._reference_features is None :
332336 raise RuntimeError (
333- "Reference features are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" ,
337+ (
338+ "Reference features are not defined. This parameter can be passed via "
339+ "SAMLearnableVisualPrompter constructor, or as an argument of infer() method"
340+ ),
334341 )
335342 reference_feats = self ._reference_features
336343
337344 if self ._used_indices is None :
338345 raise RuntimeError (
339- "Used indices are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" ,
346+ "Used indices are not defined. This parameter can be passed via "
347+ "SAMLearnableVisualPrompter constructor, or as an argument of infer() method"
340348 )
341349 used_idx = self ._used_indices
342350 else :
0 commit comments