Skip to content

Commit 7632aec

Browse files
Add test for box prompt generator
1 parent 3153467 commit 7632aec

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

test/test_prompt_generators.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ def test_point_prompt_generator(self):
6868

6969
self.assertTrue(agree.all())
7070

71+
def test_box_prompt_generator(self):
72+
from micro_sam.prompt_generators import PointAndBoxPromptGenerator
73+
from micro_sam.util import get_cell_center_coordinates
74+
75+
labels = self._get_test_data()
76+
label_ids = np.unique(labels)[1:]
77+
78+
centers, boxes = get_cell_center_coordinates(labels)
79+
generator = PointAndBoxPromptGenerator(0, 0, dilation_strength=0, get_point_prompts=False, get_box_prompts=True)
80+
81+
for label_id in label_ids:
82+
center, box_ = centers.get(label_id), boxes.get(label_id)
83+
_, _, box, _ = generator(labels, label_id, center, box_)
84+
coords = np.where(labels == label_id)
85+
expected_box = [coo.min() for coo in coords] + [coo.max() + 1 for coo in coords]
86+
self.assertEqual(expected_box, list(box[0]))
87+
7188

7289
if __name__ == "__main__":
7390
unittest.main()

0 commit comments

Comments
 (0)