Skip to content

Commit 0e11316

Browse files
Update annotator 2d to use new mask generators for more interactive instance segmentation
1 parent 9cd349b commit 0e11316

File tree

1 file changed

+45
-17
lines changed

1 file changed

+45
-17
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,29 @@ def segment_wigdet(v: Viewer):
3939
v.layers["current_object"].refresh()
4040

4141

42+
def _get_amg(is_tiled, with_background, min_initial_size, use_box, use_mask, use_points, box_extension):
43+
if is_tiled:
44+
amg = instance_segmentation.TiledEmbeddingMaskGenerator(
45+
PREDICTOR, with_background=with_background, min_initial_size=min_initial_size,
46+
use_box=use_box, use_mask=use_mask, use_points=use_points, box_extension=box_extension,
47+
)
48+
else:
49+
amg = instance_segmentation.EmbeddingMaskGenerator(
50+
PREDICTOR, min_initial_size=min_initial_size,
51+
use_box=use_box, use_mask=use_mask, use_points=use_points, box_extension=box_extension,
52+
)
53+
return amg
54+
55+
56+
def _changed_param(amg, **params):
57+
if amg is None:
58+
return None
59+
for name, val in params.items():
60+
if hasattr(amg, name) and getattr(amg, name) != val:
61+
return name
62+
return None
63+
64+
4265
@magicgui(call_button="Automatic Segmentation")
4366
def autosegment_widget(
4467
v: Viewer,
@@ -51,23 +74,27 @@ def autosegment_widget(
5174
use_points: bool = False,
5275
box_extension: float = 0.1,
5376
):
77+
global AMG
5478
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
55-
if is_tiled:
56-
seg = instance_segmentation.segment_instances_from_embeddings_with_tiling(
57-
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
58-
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
59-
stability_score_thresh=stability_score_thresh,
60-
min_initial_size=min_initial_size,
61-
use_box=use_box, use_points=use_points, use_mask=use_mask,
62-
)
63-
else:
64-
seg = instance_segmentation.segment_instances_from_embeddings(
65-
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
66-
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
67-
stability_score_thresh=stability_score_thresh,
68-
min_initial_size=min_initial_size,
69-
use_box=use_box, use_points=use_points, use_mask=use_mask,
70-
)
79+
param_changed = _changed_param(
80+
AMG, with_background=with_background, min_initial_size=min_initial_size,
81+
use_box=use_box, use_mask=use_mask, use_points=use_points,
82+
box_extension=box_extension,
83+
)
84+
if AMG is None or param_changed:
85+
if param_changed:
86+
print(f"The parameter {param_changed} was changed, so the full instance segmentation has to be recomputed.")
87+
AMG = _get_amg(is_tiled, with_background, min_initial_size, use_box, use_mask, use_points, box_extension)
88+
89+
if not AMG.is_initialized:
90+
AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
91+
92+
seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
93+
if not is_tiled:
94+
shape = v.layers["raw"].data.shape[:2]
95+
seg = instance_segmentation.mask_data_to_segmentation(seg, shape, with_background)
96+
assert isinstance(seg, np.ndarray)
97+
7198
v.layers["auto_segmentation"].data = seg
7299
v.layers["auto_segmentation"].refresh()
73100

@@ -180,7 +207,8 @@ def annotator_2d(
180207
predictor=None,
181208
):
182209
# for access to the predictor and the image embeddings in the widgets
183-
global PREDICTOR, IMAGE_EMBEDDINGS
210+
global PREDICTOR, IMAGE_EMBEDDINGS, AMG
211+
AMG = None
184212

185213
if predictor is None:
186214
PREDICTOR = util.get_sam_model(model_type=model_type)

0 commit comments

Comments
 (0)