Skip to content

Commit 801a330

Browse files
committed
WIP Update tests for prompt generators
1 parent 5507834 commit 801a330

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

test/test_prompt_generators.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,41 @@ def test_point_prompt_generator(self):
1717
from micro_sam.util import get_cell_center_coordinates
1818

1919
labels = self._get_test_data()
20-
label_ids = np.unique(labels)
20+
label_ids = np.unique(labels)[1:]
2121

2222
centers, boxes = get_cell_center_coordinates(labels)
2323

24-
test_point_pairs = [(1, 0), (1, 1), (2, 4), (3, 9)]
24+
test_point_pairs = [(4, 3)] # [(1, 0), (1, 1), (2, 4), (3, 9)]
2525
for (n_pos, n_neg) in test_point_pairs:
2626
generator = PointPromptGenerator(n_pos, n_neg, dilation_strength=4)
27-
for label_id in label_ids[:-1]: # we dodge the last element due to indexing issues for last val
28-
center, box = centers[label_id], boxes[label_id] # we start @id0 hence first mask prompts are accessed
27+
for label_id in label_ids:
28+
center, box = centers.get(label_id), boxes.get(label_id)
2929
coords, point_labels, _, _ = generator(labels, label_id, center, box)
3030
coords_ = (np.array([int(coo[0]) for coo in coords]),
3131
np.array([int(coo[1]) for coo in coords]))
32-
mask = labels == label_id + 1 # here we start @id0 hence adding 1 (to ignore background)
32+
mask = labels == label_id
3333
expected_labels = mask[coords_]
3434
agree = (point_labels == expected_labels)
3535
# DEBUG: check the points in napari if they don't match
3636
if not agree.all():
3737
print(n_pos, n_neg)
38-
# import napari
39-
# v = napari.Viewer()
40-
# v.add_image(mask)
41-
# prompts = v.add_points(
42-
# data=np.array(coords),
43-
# name="prompts",
44-
# properties={"label": point_labels},
45-
# edge_color="label",
46-
# edge_color_cycle=["#00FF00", "#FF0000"],
47-
# symbol="o",
48-
# face_color="transparent",
49-
# edge_width=0.5,
50-
# size=5,
51-
# ndim=2
52-
# ) # this function helps to view the (colored) background/foreground points
53-
# prompts.edge_color_mode = "cycle"
54-
# napari.run()
38+
import napari
39+
v = napari.Viewer()
40+
v.add_image(mask)
41+
prompts = v.add_points(
42+
data=np.array(coords),
43+
name="prompts",
44+
properties={"label": point_labels},
45+
edge_color="label",
46+
edge_color_cycle=["#00FF00", "#FF0000"],
47+
symbol="o",
48+
face_color="transparent",
49+
edge_width=0.5,
50+
size=5,
51+
ndim=2
52+
) # this function helps to view the (colored) background/foreground points
53+
prompts.edge_color_mode = "cycle"
54+
napari.run()
5555

5656
self.assertTrue(agree.all())
5757

0 commit comments

Comments
 (0)