Skip to content

Commit c8209f2

Browse files
Merge pull request #21 from computational-cell-analytics/seg-with-bg
Implement instance segmentation with background
2 parents 4cc9d0d + c5ce8c3 commit c8209f2

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def segment_wigdet(v: Viewer):
2323

2424
# TODO expose more parameters
2525
@magicgui(call_button="Segment All Objects", method={"choices": ["default", "sam", "embeddings"]})
26-
def autosegment_widget(v: Viewer, method: str = "default"):
26+
def autosegment_widget(v: Viewer, method: str = "default", with_background: bool = True):
2727
if method in ("default", "sam"):
2828
print("Run automatic segmentation with SAM. This can take a few minutes ...")
2929
image = v.layers["raw"].data
30-
seg = segment_instances.segment_instances_sam(SAM, image)
30+
seg = segment_instances.segment_instances_sam(SAM, image, with_background=with_background)
3131
elif method == "embeddings":
3232
seg = segment_instances.segment_instances_from_embeddings(PREDICTOR, IMAGE_EMBEDDINGS)
3333
else:

micro_sam/segment_instances.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#
2020

2121

22-
def segment_instances_sam(sam, image, **kwargs):
22+
def segment_instances_sam(sam, image, with_background=False, **kwargs):
2323
segmentor = SamAutomaticMaskGenerator(sam, **kwargs)
2424

2525
image_ = util._to_image(image)
@@ -30,6 +30,12 @@ def segment_instances_sam(sam, image, **kwargs):
3030
for seg_id, mask in enumerate(masks, 1):
3131
segmentation[mask["segmentation"]] = seg_id
3232

33+
seg_ids, sizes = np.unique(segmentation, return_counts=True)
34+
if with_background and 0 not in seg_ids:
35+
bg_id = seg_ids[np.argmax(seg_ids)]
36+
segmentation[segmentation == bg_id] = 0
37+
vigra.analysis.relabelConsecutive(segmentation, out=segmentation)
38+
3339
return segmentation
3440

3541

test/test_segment_instances.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010
class TestSegmentInstances(unittest.TestCase):
1111

1212
# create an input image with three objects
13-
def _get_input(self, shape=(128, 128)):
13+
def _get_input(self, shape=(96, 96)):
1414
mask = np.zeros(shape, dtype="uint8")
1515

1616
def write_object(center, radius):
1717
circle = disk(center, radius, shape=shape)
1818
mask[circle] = 1
1919

2020
center = tuple(sh // 4 for sh in shape)
21-
write_object(center, radius=10)
21+
write_object(center, radius=8)
2222

2323
center = tuple(sh // 2 for sh in shape)
2424
write_object(center, radius=9)
2525

2626
center = tuple(3 * sh // 4 for sh in shape)
27-
write_object(center, radius=11)
27+
write_object(center, radius=7)
2828

2929
image = mask * 255
3030
return mask, image

0 commit comments

Comments
 (0)