Skip to content

Commit 5ec88e1

Browse files
Implement test for non-square segmentation form mask (currently failing)
1 parent 7b6166a commit 5ec88e1

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

test/test_segment_from_prompts.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88

99
class TestSegmentFromPrompts(unittest.TestCase):
10-
def _get_input(self):
11-
shape = (256, 256)
10+
def _get_input(self, shape=(256, 256)):
1211
mask = np.zeros(shape, dtype="uint8")
13-
circle = disk((128, 128), radius=20, shape=shape)
12+
center = tuple(sh // 2 for sh in shape)
13+
circle = disk(center, radius=20, shape=shape)
1414
mask[circle] = 1
1515
image = mask * 255
1616
return mask, image
@@ -33,24 +33,32 @@ def test_segment_from_points(self):
3333
predicted = segment_from_points(predictor, points, labels)
3434
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
3535

36-
def test_segment_from_mask(self):
36+
def _test_segment_from_mask(self, shape=(256, 256)):
3737
from micro_sam.segment_from_prompts import segment_from_mask
3838

39-
mask, image = self._get_input()
39+
mask, image = self._get_input(shape)
4040
predictor = self._get_model(image)
4141

4242
# with mask and bounding box (default setting)
4343
predicted = segment_from_mask(predictor, mask)
4444
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
4545

46-
# with bounding box (default setting)
46+
# with bounding box
4747
predicted = segment_from_mask(predictor, mask, use_mask=False, use_box=True)
4848
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
4949

50-
# with bounding box (default setting)
50+
# with mask
5151
predicted = segment_from_mask(predictor, mask, use_mask=True, use_box=False)
5252
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
5353

54+
def test_segment_from_mask(self):
55+
self._test_segment_from_mask()
56+
57+
# FIXME this fails due to shape mismatch in the masks
58+
@unittest.expectedFailure
59+
def test_segment_from_mask_non_square(self):
60+
self._test_segment_from_mask((256, 384))
61+
5462
def test_segment_from_box(self):
5563
from micro_sam.segment_from_prompts import segment_from_box
5664

0 commit comments

Comments
 (0)