@@ -12,18 +12,45 @@ def _get_test_data(self):
1212 labels = label (data )
1313 return labels
1414
15+ def _debug (self , mask , center , box , coords , point_labels ):
16+ import napari
17+
18+ v = napari .Viewer ()
19+ v .add_image (mask )
20+ v .add_points ([center ], name = "center" )
21+ v .add_shapes (
22+ [np .array (
23+ [[box [0 ], box [1 ]], [box [2 ], box [3 ]]]
24+ )],
25+ shape_type = "rectangle"
26+ )
27+ prompts = v .add_points (
28+ data = np .array (coords ),
29+ name = "prompts" ,
30+ properties = {"label" : point_labels },
31+ edge_color = "label" ,
32+ edge_color_cycle = ["#00FF00" , "#FF0000" ],
33+ symbol = "o" ,
34+ face_color = "transparent" ,
35+ edge_width = 0.5 ,
36+ size = 5 ,
37+ ndim = 2
38+ ) # this function helps to view the (colored) background/foreground points
39+ prompts .edge_color_mode = "cycle"
40+ napari .run ()
41+
1542 def test_point_prompt_generator (self ):
16- from micro_sam .prompt_generators import PointPromptGenerator
43+ from micro_sam .prompt_generators import PointAndBoxPromptGenerator
1744 from micro_sam .util import get_cell_center_coordinates
1845
1946 labels = self ._get_test_data ()
2047 label_ids = np .unique (labels )[1 :]
2148
2249 centers , boxes = get_cell_center_coordinates (labels )
2350
24- test_point_pairs = [(4 , 3 )] # [ (1, 0 ), (1, 1 ), (2, 4), (3, 9)]
51+ test_point_pairs = [(1 , 0 ), (1 , 1 ), (4 , 3 ), (2 , 4 ), (3 , 9 ), ( 13 , 27 )]
2552 for (n_pos , n_neg ) in test_point_pairs :
26- generator = PointPromptGenerator (n_pos , n_neg , dilation_strength = 4 )
53+ generator = PointAndBoxPromptGenerator (n_pos , n_neg , dilation_strength = 4 )
2754 for label_id in label_ids :
2855 center , box = centers .get (label_id ), boxes .get (label_id )
2956 coords , point_labels , _ , _ = generator (labels , label_id , center , box )
@@ -32,26 +59,12 @@ def test_point_prompt_generator(self):
3259 mask = labels == label_id
3360 expected_labels = mask [coords_ ]
3461 agree = (point_labels == expected_labels )
62+
3563 # DEBUG: check the points in napari if they don't match
36- if not agree .all ():
64+ debug = False
65+ if not agree .all () and debug :
3766 print (n_pos , n_neg )
38- import napari
39- v = napari .Viewer ()
40- v .add_image (mask )
41- prompts = v .add_points (
42- data = np .array (coords ),
43- name = "prompts" ,
44- properties = {"label" : point_labels },
45- edge_color = "label" ,
46- edge_color_cycle = ["#00FF00" , "#FF0000" ],
47- symbol = "o" ,
48- face_color = "transparent" ,
49- edge_width = 0.5 ,
50- size = 5 ,
51- ndim = 2
52- ) # this function helps to view the (colored) background/foreground points
53- prompts .edge_color_mode = "cycle"
54- napari .run ()
67+ self ._debug (mask , center , box , coords , point_labels )
5568
5669 self .assertTrue (agree .all ())
5770
0 commit comments