Skip to content

Commit 02cede4

Browse files
authored
Fix masks merging in ZSL VPT (#200)
* Fix masks merging * Update refs for zsl * Del redundant mask clip
1 parent 32740e5 commit 02cede4

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

model_api/python/model_api/models/visual_prompting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,7 @@ def learn(
286286
masks = _polygon_to_mask(inputs_decoder["polygon"], *original_shape)
287287
else:
288288
raise RuntimeError("Unsupported type of prompt")
289-
ref_mask[masks] += 1
290-
ref_mask = np.clip(ref_mask, 0, 1)
289+
ref_mask = np.where(masks, 1, ref_mask)
291290

292291
ref_feat: np.ndarray | None = None
293292
cur_default_threshold_reference = self._default_threshold_reference

tests/python/accuracy/public_scope.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@
406406
"test_data": [
407407
{
408408
"image": "coco128/images/train2017/000000000471.jpg",
409-
"reference": ["mask sum: 14991; [385.0, 315.0] iou: 0.930 [44.0, 205.0] iou: 0.665 [605.0, 224.0] iou: 0.653, mask sum: 248221; [374.0, 365.0] iou: 0.901 [335.0, 34.0] iou: 0.901 [354.0, 135.0] iou: 0.709"]
409+
"reference": ["mask sum: 108565; [385.0, 315.0] iou: 0.930 [335.0, 414.0] iou: 0.763 [44.0, 205.0] iou: 0.665 [605.0, 224.0] iou: 0.653, mask sum: 73920; [175.0, 215.0] iou: 0.781 [124.0, 165.0] iou: 0.651"]
410410
}
411411
]
412412
},

0 commit comments

Comments
 (0)