1111
1212class TestInstanceSegmentation (unittest .TestCase ):
1313 embedding_path = "./tmp_embeddings.zarr"
14- tile_shape = (576 , 576 )
15- halo = (64 , 64 )
14+ tile_shape = (512 , 512 )
15+ halo = (96 , 96 )
1616
1717 # create an input image with three objects
1818 @staticmethod
@@ -24,13 +24,13 @@ def write_object(center, radius):
2424 mask [circle ] = 1
2525
2626 center = tuple (sh // 4 for sh in shape )
27- write_object (center , radius = 19 )
27+ write_object (center , radius = 29 )
2828
2929 center = tuple (sh // 2 for sh in shape )
30- write_object (center , radius = 27 )
30+ write_object (center , radius = 33 )
3131
3232 center = tuple (3 * sh // 4 for sh in shape )
33- write_object (center , radius = 22 )
33+ write_object (center , radius = 35 )
3434
3535 image = mask * 255
3636 mask = label (mask )
@@ -53,7 +53,6 @@ def setUpClass(cls):
5353 cls .predictor , cls .large_image , save_path = cls .embedding_path , tile_shape = cls .tile_shape , halo = cls .halo
5454 )
5555
56- # remove temp embeddings if any
5756 @classmethod
5857 def tearDownClass (cls ):
5958 try :
@@ -72,7 +71,7 @@ def test_automatic_mask_generator(self):
7271
7372 predicted = amg .generate ()
7473 predicted = mask_data_to_segmentation (predicted , image .shape , with_background = True )
75- self .assertGreater (matching (predicted , mask , threshold = 0.75 )["precision " ], 0.99 )
74+ self .assertGreater (matching (predicted , mask , threshold = 0.75 )["accuracy " ], 0.99 )
7675
7776 # check that regenerating the segmentation works
7877 predicted2 = amg .generate ()
@@ -98,7 +97,7 @@ def test_embedding_mask_generator(self):
9897 predicted = amg .generate (pred_iou_thresh = 0.96 )
9998 predicted = mask_data_to_segmentation (predicted , image .shape , with_background = True )
10099
101- self .assertGreater (matching (predicted , mask , threshold = 0.75 )["precision " ], 0.99 )
100+ self .assertGreater (matching (predicted , mask , threshold = 0.75 )["accuracy " ], 0.99 )
102101
103102 initial_seg = amg .get_initial_segmentation ()
104103 self .assertEqual (initial_seg .shape , image .shape )
@@ -127,7 +126,7 @@ def test_tiled_embedding_mask_generator(self):
127126 predicted = amg .generate (pred_iou_thresh = 0.96 )
128127 initial_seg = amg .get_initial_segmentation ()
129128
130- self .assertGreater (matching (predicted , mask , threshold = 0.75 )["precision " ], 0.99 )
129+ self .assertGreater (matching (predicted , mask , threshold = 0.75 )["accuracy " ], 0.99 )
131130 self .assertEqual (initial_seg .shape , image .shape )
132131
133132 predicted2 = amg .generate (pred_iou_thresh = 0.96 )
@@ -146,22 +145,24 @@ def test_tiled_automatic_mask_generator(self):
146145 mask , image = self .large_mask , self .large_image
147146 predictor , image_embeddings = self .predictor , self .tiled_embeddings
148147
149- amg = TiledAutomaticMaskGenerator (predictor )
150- amg .initialize (image , image_embeddings = image_embeddings )
151- predicted = amg .generate (pred_iou_thresh = 0.96 )
148+ pred_iou_thresh = 0.75
149+
150+ amg = TiledAutomaticMaskGenerator (predictor , points_per_side = 8 )
151+ amg .initialize (image , image_embeddings = image_embeddings , verbose = False )
152+ predicted = amg .generate (pred_iou_thresh = pred_iou_thresh )
152153 predicted = mask_data_to_segmentation (predicted , image .shape , with_background = True )
153- self .assertGreater (matching (predicted , mask , threshold = 0.75 )["precision " ], 0.99 )
154+ self .assertGreater (matching (predicted , mask , threshold = 0.75 )["accuracy " ], 0.99 )
154155
155- predicted2 = amg .generate (pred_iou_thresh = 0.96 )
156- predicted = mask_data_to_segmentation (predicted2 , image .shape , with_background = True )
156+ predicted2 = amg .generate (pred_iou_thresh = pred_iou_thresh )
157+ predicted2 = mask_data_to_segmentation (predicted2 , image .shape , with_background = True )
157158 self .assertTrue (np .array_equal (predicted , predicted2 ))
158159
159160 # check that serializing and reserializing the state works
160161 state = amg .get_state ()
161162 amg = TiledAutomaticMaskGenerator (predictor )
162163 amg .set_state (state )
163- predicted3 = amg .generate (pred_iou_thresh = 0.96 )
164- predicted = mask_data_to_segmentation (predicted3 , image .shape , with_background = True )
164+ predicted3 = amg .generate (pred_iou_thresh = pred_iou_thresh )
165+ predicted3 = mask_data_to_segmentation (predicted3 , image .shape , with_background = True )
165166 self .assertTrue (np .array_equal (predicted , predicted3 ))
166167
167168
0 commit comments