Skip to content

Commit 65833ad

Browse files
Update the instance segmentation test
1 parent c987e91 commit 65833ad

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

test/test_instance_segmentation.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
class TestInstanceSegmentation(unittest.TestCase):
1313
embedding_path = "./tmp_embeddings.zarr"
14-
tile_shape = (576, 576)
15-
halo = (64, 64)
14+
tile_shape = (512, 512)
15+
halo = (96, 96)
1616

1717
# create an input image with three objects
1818
@staticmethod
@@ -24,13 +24,13 @@ def write_object(center, radius):
2424
mask[circle] = 1
2525

2626
center = tuple(sh // 4 for sh in shape)
27-
write_object(center, radius=19)
27+
write_object(center, radius=29)
2828

2929
center = tuple(sh // 2 for sh in shape)
30-
write_object(center, radius=27)
30+
write_object(center, radius=33)
3131

3232
center = tuple(3 * sh // 4 for sh in shape)
33-
write_object(center, radius=22)
33+
write_object(center, radius=35)
3434

3535
image = mask * 255
3636
mask = label(mask)
@@ -53,7 +53,6 @@ def setUpClass(cls):
5353
cls.predictor, cls.large_image, save_path=cls.embedding_path, tile_shape=cls.tile_shape, halo=cls.halo
5454
)
5555

56-
# remove temp embeddings if any
5756
@classmethod
5857
def tearDownClass(cls):
5958
try:
@@ -72,7 +71,7 @@ def test_automatic_mask_generator(self):
7271

7372
predicted = amg.generate()
7473
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
75-
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
74+
self.assertGreater(matching(predicted, mask, threshold=0.75)["accuracy"], 0.99)
7675

7776
# check that regenerating the segmentation works
7877
predicted2 = amg.generate()
@@ -98,7 +97,7 @@ def test_embedding_mask_generator(self):
9897
predicted = amg.generate(pred_iou_thresh=0.96)
9998
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
10099

101-
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
100+
self.assertGreater(matching(predicted, mask, threshold=0.75)["accuracy"], 0.99)
102101

103102
initial_seg = amg.get_initial_segmentation()
104103
self.assertEqual(initial_seg.shape, image.shape)
@@ -127,7 +126,7 @@ def test_tiled_embedding_mask_generator(self):
127126
predicted = amg.generate(pred_iou_thresh=0.96)
128127
initial_seg = amg.get_initial_segmentation()
129128

130-
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
129+
self.assertGreater(matching(predicted, mask, threshold=0.75)["accuracy"], 0.99)
131130
self.assertEqual(initial_seg.shape, image.shape)
132131

133132
predicted2 = amg.generate(pred_iou_thresh=0.96)
@@ -146,22 +145,24 @@ def test_tiled_automatic_mask_generator(self):
146145
mask, image = self.large_mask, self.large_image
147146
predictor, image_embeddings = self.predictor, self.tiled_embeddings
148147

149-
amg = TiledAutomaticMaskGenerator(predictor)
150-
amg.initialize(image, image_embeddings=image_embeddings)
151-
predicted = amg.generate(pred_iou_thresh=0.96)
148+
pred_iou_thresh = 0.75
149+
150+
amg = TiledAutomaticMaskGenerator(predictor, points_per_side=8)
151+
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
152+
predicted = amg.generate(pred_iou_thresh=pred_iou_thresh)
152153
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
153-
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
154+
self.assertGreater(matching(predicted, mask, threshold=0.75)["accuracy"], 0.99)
154155

155-
predicted2 = amg.generate(pred_iou_thresh=0.96)
156-
predicted = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
156+
predicted2 = amg.generate(pred_iou_thresh=pred_iou_thresh)
157+
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
157158
self.assertTrue(np.array_equal(predicted, predicted2))
158159

159160
# check that serializing and reserializing the state works
160161
state = amg.get_state()
161162
amg = TiledAutomaticMaskGenerator(predictor)
162163
amg.set_state(state)
163-
predicted3 = amg.generate(pred_iou_thresh=0.96)
164-
predicted = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
164+
predicted3 = amg.generate(pred_iou_thresh=pred_iou_thresh)
165+
predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
165166
self.assertTrue(np.array_equal(predicted, predicted3))
166167

167168

0 commit comments

Comments
 (0)