Skip to content

Commit 8d52cdc

Browse files
Fix the prompt generator test
1 parent 0b17190 commit 8d52cdc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/test_prompt_generators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_point_prompt_generator(self):
5353
generator = PointAndBoxPromptGenerator(n_pos, n_neg, dilation_strength=4)
5454
for label_id in label_ids:
5555
center, box = centers.get(label_id), boxes.get(label_id)
56-
coords, point_labels, _, _ = generator(labels, label_id, center, box)
56+
coords, point_labels, _, _ = generator(labels, label_id, box, center)
5757
coords_ = (np.array([int(coo[0]) for coo in coords]),
5858
np.array([int(coo[1]) for coo in coords]))
5959
mask = labels == label_id
@@ -80,7 +80,7 @@ def test_box_prompt_generator(self):
8080

8181
for label_id in label_ids:
8282
center, box_ = centers.get(label_id), boxes.get(label_id)
83-
_, _, box, _ = generator(labels, label_id, center, box_)
83+
_, _, box, _ = generator(labels, label_id, box_, center)
8484
coords = np.where(labels == label_id)
8585
expected_box = [coo.min() for coo in coords] + [coo.max() + 1 for coo in coords]
8686
self.assertEqual(expected_box, list(box[0]))

0 commit comments

Comments
 (0)