77
88
99class TestSegmentFromPrompts (unittest .TestCase ):
10- def _get_input (self , shape = (256 , 256 )):
10+ @staticmethod
11+ def _get_input (shape = (256 , 256 )):
1112 mask = np .zeros (shape , dtype = "uint8" )
1213 center = tuple (sh // 2 for sh in shape )
1314 circle = disk (center , radius = 20 , shape = shape )
1415 mask [circle ] = 1
1516 image = mask * 255
1617 return mask , image
1718
18- def _get_model (self , image ):
19+ @staticmethod
20+ def _get_model (image ):
1921 predictor = util .get_sam_model (model_type = "vit_b" )
2022 image_embeddings = util .precompute_image_embeddings (predictor , image )
2123 util .set_precomputed (predictor , image_embeddings )
2224 return predictor
2325
26+ # we compute the default mask and predictor once for the class
27+ # so that we don't have to precompute it every time
28+ @classmethod
29+ def setUpClass (cls ):
30+ cls .mask , cls .image = cls ._get_input ()
31+ cls .predictor = cls ._get_model (cls .image )
32+
2433 def test_segment_from_points (self ):
2534 from micro_sam .segment_from_prompts import segment_from_points
2635
27- mask , image = self ._get_input ()
28- predictor = self ._get_model (image )
29-
3036 points = np .array ([[128 , 128 ], [64 , 64 ], [192 , 192 ], [64 , 192 ], [192 , 64 ]])
3137 labels = np .array ([1 , 0 , 0 , 0 , 0 ])
3238
33- predicted = segment_from_points (predictor , points , labels )
34- self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
39+ predicted = segment_from_points (self . predictor , points , labels )
40+ self .assertGreater (util .compute_iou (self . mask , predicted ), 0.9 )
3541
3642 def _test_segment_from_mask (self , shape = (256 , 256 ), use_mask = True ):
3743 from micro_sam .segment_from_prompts import segment_from_mask
3844
39- mask , image = self ._get_input (shape )
40- predictor = self ._get_model (image )
45+ if shape == (256 , 256 ):
46+ mask , image = self .mask , self .image
47+ predictor = self .predictor
48+ else :
49+ mask , image = self ._get_input (shape )
50+ predictor = self ._get_model (image )
4151
4252 # with mask and bounding box (default setting)
4353 if use_mask :
@@ -64,12 +74,19 @@ def test_segment_from_mask_non_square(self):
6474 def test_segment_from_box (self ):
6575 from micro_sam .segment_from_prompts import segment_from_box
6676
67- mask , image = self ._get_input ()
68- predictor = self ._get_model (image )
77+ box = np .array ([106 , 106 , 150 , 150 ])
78+ predicted = segment_from_box (self .predictor , box )
79+ self .assertGreater (util .compute_iou (self .mask , predicted ), 0.9 )
80+
81+ def test_segment_from_box_and_points (self ):
82+ from micro_sam .segment_from_prompts import segment_from_box_and_points
6983
7084 box = np .array ([106 , 106 , 150 , 150 ])
71- predicted = segment_from_box (predictor , box )
72- self .assertGreater (util .compute_iou (mask , predicted ), 0.9 )
85+ points = np .array ([[128 , 128 ], [64 , 64 ], [192 , 192 ], [64 , 192 ], [192 , 64 ]])
86+ labels = np .array ([1 , 0 , 0 , 0 , 0 ])
87+
88+ predicted = segment_from_box_and_points (self .predictor , box , points , labels )
89+ self .assertGreater (util .compute_iou (self .mask , predicted ), 0.9 )
7390
7491
7592if __name__ == "__main__" :
0 commit comments