@@ -53,7 +53,7 @@ def test_point_prompt_generator(self):
5353 generator = PointAndBoxPromptGenerator (n_pos , n_neg , dilation_strength = 4 )
5454 for label_id in label_ids :
5555 center , box = centers .get (label_id ), boxes .get (label_id )
56- coords , point_labels , _ , _ = generator (labels , label_id , center , box )
56+ coords , point_labels , _ , _ = generator (labels , label_id , box , center )
5757 coords_ = (np .array ([int (coo [0 ]) for coo in coords ]),
5858 np .array ([int (coo [1 ]) for coo in coords ]))
5959 mask = labels == label_id
@@ -80,7 +80,7 @@ def test_box_prompt_generator(self):
8080
8181 for label_id in label_ids :
8282 center , box_ = centers .get (label_id ), boxes .get (label_id )
83- _ , _ , box , _ = generator (labels , label_id , center , box_ )
83+ _ , _ , box , _ = generator (labels , label_id , box_ , center )
8484 coords = np .where (labels == label_id )
8585 expected_box = [coo .min () for coo in coords ] + [coo .max () + 1 for coo in coords ]
8686 self .assertEqual (expected_box , list (box [0 ]))
0 commit comments