Skip to content

Commit ebfc4be

Browse files
Merge pull request #16 from computational-cell-analytics/test-prompts
Add test for the point prompt generator
2 parents e220a84 + 7632aec commit ebfc4be

File tree

3 files changed

+96
-4
lines changed

3 files changed

+96
-4
lines changed

micro_sam/prompt_generators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __call__(self, gt, gt_id, center_coordinates, bbox_coordinates):
3232
if self.get_box_prompts:
3333
bbox_list = [bbox_coordinates]
3434

35-
object_mask = gt == gt_id + 1 # alloting a label id to obtain the coordinates of desired seeds
35+
object_mask = gt == gt_id
3636

3737
# getting the additional positive points by randomly sampling points from this mask
3838
n_positive_remaining = self.n_positive_points - 1

micro_sam/util.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def compute_iou(mask1, mask2):
295295
return iou
296296

297297

298-
def get_cell_center_coordinates(gt, mode="p"):
298+
def get_cell_center_coordinates(gt, mode="v"):
299299
"""
300300
Returns the center coordinates of the foreground instances in the ground-truth
301301
"""
@@ -304,12 +304,14 @@ def get_cell_center_coordinates(gt, mode="p"):
304304
properties = regionprops(gt)
305305

306306
if mode == "p":
307-
center_coordinates = [prop.centroid for prop in properties]
307+
center_coordinates = {prop.label: prop.centroid for prop in properties}
308308
elif mode == "v":
309309
center_coordinates = vigra.filters.eccentricityCenters(gt.astype('float32'))
310+
center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
310311

311-
bbox_coordinates = [prop.bbox for prop in properties]
312+
bbox_coordinates = {prop.label: prop.bbox for prop in properties}
312313

314+
assert len(bbox_coordinates) == len(center_coordinates)
313315
return center_coordinates, bbox_coordinates
314316

315317

test/test_prompt_generators.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import unittest
2+
import numpy as np
3+
4+
from skimage.data import binary_blobs
5+
from skimage.measure import label
6+
7+
8+
class TestPromptGenerators(unittest.TestCase):
9+
10+
def _get_test_data(self):
11+
data = binary_blobs(length=256)
12+
labels = label(data)
13+
return labels
14+
15+
def _debug(self, mask, center, box, coords, point_labels):
16+
import napari
17+
18+
v = napari.Viewer()
19+
v.add_image(mask)
20+
v.add_points([center], name="center")
21+
v.add_shapes(
22+
[np.array(
23+
[[box[0], box[1]], [box[2], box[3]]]
24+
)],
25+
shape_type="rectangle"
26+
)
27+
prompts = v.add_points(
28+
data=np.array(coords),
29+
name="prompts",
30+
properties={"label": point_labels},
31+
edge_color="label",
32+
edge_color_cycle=["#00FF00", "#FF0000"],
33+
symbol="o",
34+
face_color="transparent",
35+
edge_width=0.5,
36+
size=5,
37+
ndim=2
38+
) # this function helps to view the (colored) background/foreground points
39+
prompts.edge_color_mode = "cycle"
40+
napari.run()
41+
42+
def test_point_prompt_generator(self):
43+
from micro_sam.prompt_generators import PointAndBoxPromptGenerator
44+
from micro_sam.util import get_cell_center_coordinates
45+
46+
labels = self._get_test_data()
47+
label_ids = np.unique(labels)[1:]
48+
49+
centers, boxes = get_cell_center_coordinates(labels)
50+
51+
test_point_pairs = [(1, 0), (1, 1), (4, 3), (2, 4), (3, 9), (13, 27)]
52+
for (n_pos, n_neg) in test_point_pairs:
53+
generator = PointAndBoxPromptGenerator(n_pos, n_neg, dilation_strength=4)
54+
for label_id in label_ids:
55+
center, box = centers.get(label_id), boxes.get(label_id)
56+
coords, point_labels, _, _ = generator(labels, label_id, center, box)
57+
coords_ = (np.array([int(coo[0]) for coo in coords]),
58+
np.array([int(coo[1]) for coo in coords]))
59+
mask = labels == label_id
60+
expected_labels = mask[coords_]
61+
agree = (point_labels == expected_labels)
62+
63+
# DEBUG: check the points in napari if they don't match
64+
debug = False
65+
if not agree.all() and debug:
66+
print(n_pos, n_neg)
67+
self._debug(mask, center, box, coords, point_labels)
68+
69+
self.assertTrue(agree.all())
70+
71+
def test_box_prompt_generator(self):
72+
from micro_sam.prompt_generators import PointAndBoxPromptGenerator
73+
from micro_sam.util import get_cell_center_coordinates
74+
75+
labels = self._get_test_data()
76+
label_ids = np.unique(labels)[1:]
77+
78+
centers, boxes = get_cell_center_coordinates(labels)
79+
generator = PointAndBoxPromptGenerator(0, 0, dilation_strength=0, get_point_prompts=False, get_box_prompts=True)
80+
81+
for label_id in label_ids:
82+
center, box_ = centers.get(label_id), boxes.get(label_id)
83+
_, _, box, _ = generator(labels, label_id, center, box_)
84+
coords = np.where(labels == label_id)
85+
expected_box = [coo.min() for coo in coords] + [coo.max() + 1 for coo in coords]
86+
self.assertEqual(expected_box, list(box[0]))
87+
88+
89+
if __name__ == "__main__":
90+
unittest.main()

0 commit comments

Comments
 (0)