Skip to content

Commit 4ee1876

Browse files
Add test for the point prompt generator WIP
1 parent ea84438 commit 4ee1876

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

test/test_prompt_generators.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
import numpy as np
3+
4+
from skimage.data import binary_blobs
5+
from skimage.measure import label
6+
7+
8+
class TestPromptGenerators(unittest.TestCase):
9+
10+
def _get_test_data(self):
11+
data = binary_blobs(length=256)
12+
labels = label(data)
13+
return labels
14+
15+
def test_point_prompt_generator(self):
16+
from micro_sam.prompt_generators import PointPromptGenerator
17+
from micro_sam.util import get_cell_center_coordinates
18+
19+
labels = self._get_test_data()
20+
label_ids = np.unique(labels)[1:]
21+
centers, boxes = get_cell_center_coordinates(labels)
22+
23+
test_point_pairs = [(1, 0), (1, 1), (2, 4), (3, 9)]
24+
for (n_pos, n_neg) in test_point_pairs:
25+
generator = PointPromptGenerator(n_pos, n_neg, dilation_strength=4)
26+
for label_id in label_ids:
27+
center, box = centers[label_id], boxes[label_id]
28+
coords, point_labels, _, _ = generator(labels, label_id, center, box)
29+
coords_ = (np.array([int(coo[0]) for coo in coords]),
30+
np.array([int(coo[1]) for coo in coords]))
31+
mask = labels == label_id
32+
expected_labels = mask[coords_]
33+
agree = (point_labels == expected_labels)
34+
# DEBUG: check the points in napari if they don't match
35+
if not agree.all():
36+
print(n_pos, n_neg)
37+
# import napari
38+
# v = napari.Viewer()
39+
# v.add_image(mask)
40+
# v.add_points(coords)
41+
# napari.run()
42+
self.assertTrue(agree.all())
43+
44+
45+
if __name__ == "__main__":
46+
unittest.main()

0 commit comments

Comments
 (0)