Skip to content

Commit 53a675c

Browse files
committed
Update tests for Prompt Generators
1 parent 1e7c490 commit 53a675c

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

test/test_prompt_generators.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from skimage.data import binary_blobs
55
from skimage.measure import label
66

7+
COLOR_CYCLE = ["#00FF00", "#FF0000"]
8+
79

810
class TestPromptGenerators(unittest.TestCase):
911

@@ -17,18 +19,22 @@ def test_point_prompt_generator(self):
1719
from micro_sam.util import get_cell_center_coordinates
1820

1921
labels = self._get_test_data()
20-
label_ids = np.unique(labels)[1:]
21-
centers, boxes = get_cell_center_coordinates(labels)
22+
label_ids = np.unique(labels)
23+
24+
mode = "v" # this gets center points using vigra - regionprops throws error for asymmetrical blobs
25+
centers, boxes = get_cell_center_coordinates(labels, mode=mode)
26+
if mode == "v":
27+
centers.pop(0) # to avoid the background element
2228

2329
test_point_pairs = [(1, 0), (1, 1), (2, 4), (3, 9)]
2430
for (n_pos, n_neg) in test_point_pairs:
2531
generator = PointPromptGenerator(n_pos, n_neg, dilation_strength=4)
26-
for label_id in label_ids:
27-
center, box = centers[label_id], boxes[label_id]
32+
for label_id in label_ids[:-1]: # we dodge the last element due to indexing issues for last val
33+
center, box = centers[label_id], boxes[label_id] # we start @id0 hence first mask prompts are accessed
2834
coords, point_labels, _, _ = generator(labels, label_id, center, box)
2935
coords_ = (np.array([int(coo[0]) for coo in coords]),
3036
np.array([int(coo[1]) for coo in coords]))
31-
mask = labels == label_id
37+
mask = labels == label_id + 1 # here we start @id0 hence adding 1 (to ignore background)
3238
expected_labels = mask[coords_]
3339
agree = (point_labels == expected_labels)
3440
# DEBUG: check the points in napari if they don't match
@@ -37,8 +43,21 @@ def test_point_prompt_generator(self):
3743
# import napari
3844
# v = napari.Viewer()
3945
# v.add_image(mask)
40-
# v.add_points(coords)
46+
# prompts = v.add_points(
47+
# data=np.array(coords),
48+
# name="prompts",
49+
# properties={"label": point_labels},
50+
# edge_color="label",
51+
# edge_color_cycle=COLOR_CYCLE,
52+
# symbol="o",
53+
# face_color="transparent",
54+
# edge_width=0.5,
55+
# size=5,
56+
# ndim=2
57+
# ) # this function helps to view the (colored) background/foreground points
58+
# prompts.edge_color_mode = "cycle"
4159
# napari.run()
60+
4261
self.assertTrue(agree.all())
4362

4463

0 commit comments

Comments
 (0)