Skip to content

Commit 3153467

Browse files
Refactor debug code in point generator test
1 parent 7a4236a commit 3153467

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

test/test_prompt_generators.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,45 @@ def _get_test_data(self):
1212
labels = label(data)
1313
return labels
1414

15+
def _debug(self, mask, center, box, coords, point_labels):
16+
import napari
17+
18+
v = napari.Viewer()
19+
v.add_image(mask)
20+
v.add_points([center], name="center")
21+
v.add_shapes(
22+
[np.array(
23+
[[box[0], box[1]], [box[2], box[3]]]
24+
)],
25+
shape_type="rectangle"
26+
)
27+
prompts = v.add_points(
28+
data=np.array(coords),
29+
name="prompts",
30+
properties={"label": point_labels},
31+
edge_color="label",
32+
edge_color_cycle=["#00FF00", "#FF0000"],
33+
symbol="o",
34+
face_color="transparent",
35+
edge_width=0.5,
36+
size=5,
37+
ndim=2
38+
) # this function helps to view the (colored) background/foreground points
39+
prompts.edge_color_mode = "cycle"
40+
napari.run()
41+
1542
def test_point_prompt_generator(self):
16-
from micro_sam.prompt_generators import PointPromptGenerator
43+
from micro_sam.prompt_generators import PointAndBoxPromptGenerator
1744
from micro_sam.util import get_cell_center_coordinates
1845

1946
labels = self._get_test_data()
2047
label_ids = np.unique(labels)[1:]
2148

2249
centers, boxes = get_cell_center_coordinates(labels)
2350

24-
test_point_pairs = [(4, 3)] # [(1, 0), (1, 1), (2, 4), (3, 9)]
51+
test_point_pairs = [(1, 0), (1, 1), (4, 3), (2, 4), (3, 9), (13, 27)]
2552
for (n_pos, n_neg) in test_point_pairs:
26-
generator = PointPromptGenerator(n_pos, n_neg, dilation_strength=4)
53+
generator = PointAndBoxPromptGenerator(n_pos, n_neg, dilation_strength=4)
2754
for label_id in label_ids:
2855
center, box = centers.get(label_id), boxes.get(label_id)
2956
coords, point_labels, _, _ = generator(labels, label_id, center, box)
@@ -32,26 +59,12 @@ def test_point_prompt_generator(self):
3259
mask = labels == label_id
3360
expected_labels = mask[coords_]
3461
agree = (point_labels == expected_labels)
62+
3563
# DEBUG: check the points in napari if they don't match
36-
if not agree.all():
64+
debug = False
65+
if not agree.all() and debug:
3766
print(n_pos, n_neg)
38-
import napari
39-
v = napari.Viewer()
40-
v.add_image(mask)
41-
prompts = v.add_points(
42-
data=np.array(coords),
43-
name="prompts",
44-
properties={"label": point_labels},
45-
edge_color="label",
46-
edge_color_cycle=["#00FF00", "#FF0000"],
47-
symbol="o",
48-
face_color="transparent",
49-
edge_width=0.5,
50-
size=5,
51-
ndim=2
52-
) # this function helps to view the (colored) background/foreground points
53-
prompts.edge_color_mode = "cycle"
54-
napari.run()
67+
self._debug(mask, center, box, coords, point_labels)
5568

5669
self.assertTrue(agree.all())
5770

0 commit comments

Comments
 (0)