@@ -17,41 +17,41 @@ def test_point_prompt_generator(self):
1717 from micro_sam .util import get_cell_center_coordinates
1818
1919 labels = self ._get_test_data ()
20- label_ids = np .unique (labels )
20+ label_ids = np .unique (labels )[ 1 :]
2121
2222 centers , boxes = get_cell_center_coordinates (labels )
2323
24- test_point_pairs = [(1 , 0 ), (1 , 1 ), (2 , 4 ), (3 , 9 )]
24+ test_point_pairs = [(4 , 3 )] # [( 1, 0), (1, 1), (2, 4), (3, 9)]
2525 for (n_pos , n_neg ) in test_point_pairs :
2626 generator = PointPromptGenerator (n_pos , n_neg , dilation_strength = 4 )
27- for label_id in label_ids [: - 1 ]: # we dodge the last element due to indexing issues for last val
28- center , box = centers [ label_id ] , boxes [ label_id ] # we start @id0 hence first mask prompts are accessed
27+ for label_id in label_ids :
28+ center , box = centers . get ( label_id ) , boxes . get ( label_id )
2929 coords , point_labels , _ , _ = generator (labels , label_id , center , box )
3030 coords_ = (np .array ([int (coo [0 ]) for coo in coords ]),
3131 np .array ([int (coo [1 ]) for coo in coords ]))
32- mask = labels == label_id + 1 # here we start @id0 hence adding 1 (to ignore background)
32+ mask = labels == label_id
3333 expected_labels = mask [coords_ ]
3434 agree = (point_labels == expected_labels )
3535 # DEBUG: check the points in napari if they don't match
3636 if not agree .all ():
3737 print (n_pos , n_neg )
38- # import napari
39- # v = napari.Viewer()
40- # v.add_image(mask)
41- # prompts = v.add_points(
42- # data=np.array(coords),
43- # name="prompts",
44- # properties={"label": point_labels},
45- # edge_color="label",
46- # edge_color_cycle=["#00FF00", "#FF0000"],
47- # symbol="o",
48- # face_color="transparent",
49- # edge_width=0.5,
50- # size=5,
51- # ndim=2
52- # ) # this function helps to view the (colored) background/foreground points
53- # prompts.edge_color_mode = "cycle"
54- # napari.run()
38+ import napari
39+ v = napari .Viewer ()
40+ v .add_image (mask )
41+ prompts = v .add_points (
42+ data = np .array (coords ),
43+ name = "prompts" ,
44+ properties = {"label" : point_labels },
45+ edge_color = "label" ,
46+ edge_color_cycle = ["#00FF00" , "#FF0000" ],
47+ symbol = "o" ,
48+ face_color = "transparent" ,
49+ edge_width = 0.5 ,
50+ size = 5 ,
51+ ndim = 2
52+ ) # this function helps to view the (colored) background/foreground points
53+ prompts .edge_color_mode = "cycle"
54+ napari .run ()
5555
5656 self .assertTrue (agree .all ())
5757
0 commit comments