Skip to content

Commit 7b6166a

Browse files
Implement segment_from_box
1 parent 618cfb6 commit 7b6166a

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

micro_sam/segment_from_prompts.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,22 @@ def _compute_box(mask, original_size=None):
1212
coords[1].min(), coords[0].min(),
1313
coords[1].max() + 1, coords[0].max() + 1,
1414
])
15-
# FIXME how do we deal with aspect ratios???
15+
# TODO how do we deal with aspect ratios???
1616
if original_size is not None:
1717
trafo = ResizeLongestSide(max(original_size))
1818
box = trafo.apply_boxes(box[None], (256, 256)).squeeze()
1919
return box
2020

2121

22+
def _process_box(box, original_size=None):
23+
box_processed = box[[1, 0, 3, 2]]
24+
# TODO how do we deal with aspect ratios???
25+
if original_size is not None:
26+
trafo = ResizeLongestSide(max(original_size))
27+
box_processed = trafo.apply_boxes(box[None], (256, 256)).squeeze()
28+
return box_processed
29+
30+
2231
def _compute_logits(mask, eps=1e-3):
2332

2433
def inv_sigmoid(x):
@@ -77,3 +86,17 @@ def segment_from_mask(
7786
return mask, scores, logits
7887
else:
7988
return mask
89+
90+
91+
def segment_from_box(
92+
predictor, box,
93+
image_embeddings=None, i=None, original_size=None,
94+
multimask_output=False, return_all=False,
95+
):
96+
if image_embeddings is not None:
97+
util.set_precomputed(predictor, image_embeddings, i)
98+
mask, scores, logits = predictor.predict(box=_process_box(box, original_size), multimask_output=multimask_output)
99+
if return_all:
100+
return mask, scores, logits
101+
else:
102+
return mask

test/test_segment_from_prompts.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ def test_segment_from_mask(self):
5151
predicted = segment_from_mask(predictor, mask, use_mask=True, use_box=False)
5252
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
5353

54+
def test_segment_from_box(self):
55+
from micro_sam.segment_from_prompts import segment_from_box
56+
57+
mask, image = self._get_input()
58+
predictor = self._get_model(image)
59+
60+
box = np.array([106, 106, 150, 150])
61+
predicted = segment_from_box(predictor, box)
62+
self.assertGreater(util.compute_iou(mask, predicted), 0.9)
63+
5464

5565
if __name__ == "__main__":
5666
unittest.main()

0 commit comments

Comments
 (0)