44from skimage .data import binary_blobs
55from skimage .measure import label
66
7+ COLOR_CYCLE = ["#00FF00" , "#FF0000" ]
8+
79
810class TestPromptGenerators (unittest .TestCase ):
911
@@ -17,18 +19,22 @@ def test_point_prompt_generator(self):
1719 from micro_sam .util import get_cell_center_coordinates
1820
1921 labels = self ._get_test_data ()
20- label_ids = np .unique (labels )[1 :]
21- centers , boxes = get_cell_center_coordinates (labels )
22+ label_ids = np .unique (labels )
23+
24+ mode = "v" # this gets center points using vigra - regionprops throws error for asymmetrical blobs
25+ centers , boxes = get_cell_center_coordinates (labels , mode = mode )
26+ if mode == "v" :
27+ centers .pop (0 ) # to avoid the background element
2228
2329 test_point_pairs = [(1 , 0 ), (1 , 1 ), (2 , 4 ), (3 , 9 )]
2430 for (n_pos , n_neg ) in test_point_pairs :
2531 generator = PointPromptGenerator (n_pos , n_neg , dilation_strength = 4 )
26- for label_id in label_ids :
27- center , box = centers [label_id ], boxes [label_id ]
32+ for label_id in label_ids [: - 1 ]: # we dodge the last element due to indexing issues for last val
33+ center , box = centers [label_id ], boxes [label_id ] # we start @id0 hence first mask prompts are accessed
2834 coords , point_labels , _ , _ = generator (labels , label_id , center , box )
2935 coords_ = (np .array ([int (coo [0 ]) for coo in coords ]),
3036 np .array ([int (coo [1 ]) for coo in coords ]))
31- mask = labels == label_id
37+ mask = labels == label_id + 1 # here we start @id0 hence adding 1 (to ignore background)
3238 expected_labels = mask [coords_ ]
3339 agree = (point_labels == expected_labels )
3440 # DEBUG: check the points in napari if they don't match
@@ -37,8 +43,21 @@ def test_point_prompt_generator(self):
3743 # import napari
3844 # v = napari.Viewer()
3945 # v.add_image(mask)
40- # v.add_points(coords)
46+ # prompts = v.add_points(
47+ # data=np.array(coords),
48+ # name="prompts",
49+ # properties={"label": point_labels},
50+ # edge_color="label",
51+ # edge_color_cycle=COLOR_CYCLE,
52+ # symbol="o",
53+ # face_color="transparent",
54+ # edge_width=0.5,
55+ # size=5,
56+ # ndim=2
57+ # ) # this function helps to view the (colored) background/foreground points
58+ # prompts.edge_color_mode = "cycle"
4159 # napari.run()
60+
4261 self .assertTrue (agree .all ())
4362
4463
0 commit comments