diff --git a/modules/impact/core.py b/modules/impact/core.py index fd15dbdf..7e53a586 100644 --- a/modules/impact/core.py +++ b/modules/impact/core.py @@ -468,25 +468,10 @@ def sam_predict(predictor, points, plabs, bbox, threshold): cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) - total_masks = [] - - selected = False - max_score = 0 - for idx in range(len(scores)): - if scores[idx] > max_score: - max_score = scores[idx] - max_mask = cur_masks[idx] - - if scores[idx] >= threshold: - selected = True - total_masks.append(cur_masks[idx]) - else: - pass - - if not selected: - total_masks.append(max_mask) - - return total_masks + # take all 3 masks predict returns, or take none + if any([score >= threshold for score in scores]): + return [m for m in cur_masks] + return [] def make_sam_mask(sam_model, segs, image, detection_hint, dilation, @@ -606,7 +591,7 @@ def make_sam_mask(sam_model, segs, image, detection_hint, dilation, mask = dilate_mask(mask.cpu().numpy(), dilation) mask = torch.from_numpy(mask) else: - mask = torch.zeros((8, 8), dtype=torch.float32, device="cpu") # empty mask + mask = torch.zeros((image.shape[0], image.shape[1]), dtype=torch.float32, device="cpu") # empty mask mask = utils.make_3d_mask(mask) return mask