77
88
99class TestSegmentFromPrompts (unittest .TestCase ):
10- def _get_input (self ):
11- shape = (256 , 256 )
10+ def _get_input (self , shape = (256 , 256 )):
1211 mask = np .zeros (shape , dtype = "uint8" )
13- circle = disk ((128 , 128 ), radius = 20 , shape = shape )
12+ center = tuple (sh // 2 for sh in shape )
13+ circle = disk (center , radius = 20 , shape = shape )
1414 mask [circle ] = 1
1515 image = mask * 255
1616 return mask , image
@@ -33,24 +33,32 @@ def test_segment_from_points(self):
3333 predicted = segment_from_points (predictor , points , labels )
3434 self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
3535
36- def test_segment_from_mask (self ):
36+ def _test_segment_from_mask (self , shape = ( 256 , 256 ) ):
3737 from micro_sam .segment_from_prompts import segment_from_mask
3838
39- mask , image = self ._get_input ()
39+ mask , image = self ._get_input (shape )
4040 predictor = self ._get_model (image )
4141
4242 # with mask and bounding box (default setting)
4343 predicted = segment_from_mask (predictor , mask )
4444 self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
4545
46- # with bounding box (default setting)
46+ # with bounding box
4747 predicted = segment_from_mask (predictor , mask , use_mask = False , use_box = True )
4848 self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
4949
50- # with bounding box (default setting)
50+ # with mask
5151 predicted = segment_from_mask (predictor , mask , use_mask = True , use_box = False )
5252 self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
5353
54+ def test_segment_from_mask (self ):
55+ self ._test_segment_from_mask ()
56+
57+ # FIXME this fails due to shape mismatch in the masks
58+ @unittest .expectedFailure
59+ def test_segment_from_mask_non_square (self ):
60+ self ._test_segment_from_mask ((256 , 384 ))
61+
5462 def test_segment_from_box (self ):
5563 from micro_sam .segment_from_prompts import segment_from_box
5664
0 commit comments