Skip to content

Commit d89cf3e

Browse files
Merge pull request #140 from computational-cell-analytics/update-instance-seg
Use normal instance segmentation functionality in 2d annotator and pr…
2 parents d99f261 + e758122 commit d89cf3e

File tree

2 files changed

+46
-21
lines changed

2 files changed

+46
-21
lines changed

micro_sam/instance_segmentation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def mask_data_to_segmentation(
5454
masks: List[Dict[str, Any]],
5555
shape: tuple[int, ...],
5656
with_background: bool,
57+
min_object_size: int = 0,
5758
) -> np.ndarray:
5859
"""Convert the output of the automatic mask generation to an instance segmentation.
5960
@@ -63,15 +64,20 @@ def mask_data_to_segmentation(
6364
shape: The image shape.
6465
with_background: Whether the segmentation has background. If yes this function assures that the largest
6566
object in the output will be mapped to zero (the background value).
67+
min_object_size: The minimal size of an object in pixels.
6668
Returns:
6769
The instance segmentation.
6870
"""
6971

7072
masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
7173
segmentation = np.zeros(shape[:2], dtype="uint32")
7274

73-
for seg_id, mask in enumerate(masks, 1):
75+
seg_id = 1
76+
for mask in masks:
77+
if mask["area"] < min_object_size:
78+
continue
7479
segmentation[mask["segmentation"]] = seg_id
80+
seg_id += 1
7581

7682
if with_background:
7783
seg_ids, sizes = np.unique(segmentation, return_counts=True)

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pickle
13
import warnings
24
from typing import Optional, Tuple
35

@@ -38,15 +40,13 @@ def _segment_widget(v: Viewer) -> None:
3840
v.layers["current_object"].refresh()
3941

4042

41-
def _get_amg(is_tiled, with_background, min_initial_size, box_extension=0.05):
43+
def _get_amg(is_tiled, with_background=True, box_extension=0.05):
4244
if is_tiled:
43-
amg = instance_segmentation.TiledEmbeddingMaskGenerator(
44-
PREDICTOR, with_background=with_background, min_initial_size=min_initial_size, box_extension=box_extension,
45+
amg = instance_segmentation.TiledAutomaticMaskGenerator(
46+
PREDICTOR, with_background=with_background, box_extension=box_extension,
4547
)
4648
else:
47-
amg = instance_segmentation.EmbeddingMaskGenerator(
48-
PREDICTOR, min_initial_size=min_initial_size, box_extension=box_extension,
49-
)
49+
amg = instance_segmentation.EmbeddingMaskGenerator(PREDICTOR, box_extension=box_extension)
5050
return amg
5151

5252

@@ -62,29 +62,27 @@ def _changed_param(amg, **params):
6262
@magicgui(call_button="Automatic Segmentation")
6363
def _autosegment_widget(
6464
v: Viewer,
65-
with_background: bool = True,
6665
pred_iou_thresh: float = 0.88,
6766
stability_score_thresh: float = 0.95,
68-
min_initial_size: int = 10,
69-
box_extension: float = 0.05,
67+
min_object_size: int = 25,
7068
) -> None:
7169
global AMG
7270
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
73-
param_changed = _changed_param(
74-
AMG, with_background=with_background, min_initial_size=min_initial_size, box_extension=box_extension
75-
)
76-
if AMG is None or param_changed:
77-
if param_changed:
78-
print(f"The parameter {param_changed} was changed, so the full instance segmentation has to be recomputed.")
79-
AMG = _get_amg(is_tiled, with_background, min_initial_size, box_extension)
71+
if AMG is None:
72+
AMG = _get_amg(is_tiled)
8073

8174
if not AMG.is_initialized:
8275
AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
8376

84-
seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
85-
if not is_tiled:
86-
shape = v.layers["raw"].data.shape[:2]
87-
seg = instance_segmentation.mask_data_to_segmentation(seg, shape, with_background)
77+
seg = AMG.generate(
78+
pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh,
79+
min_mask_region_area=min_object_size
80+
)
81+
82+
shape = v.layers["raw"].data.shape[:2]
83+
seg = instance_segmentation.mask_data_to_segmentation(
84+
seg, shape, with_background=True, min_object_size=min_object_size
85+
)
8886
assert isinstance(seg, np.ndarray)
8987

9088
v.layers["auto_segmentation"].data = seg
@@ -192,6 +190,25 @@ def _update_viewer(v, raw, show_embeddings, segmentation_result):
192190
v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
193191

194192

193+
def _precompute_amg_state(raw, save_path):
194+
global AMG
195+
196+
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
197+
AMG = _get_amg(is_tiled)
198+
199+
save_path_amg = os.path.join(save_path, "amg_state.pickle")
200+
if os.path.exists(save_path_amg):
201+
with open(save_path_amg, "rb") as f:
202+
amg_state = pickle.load(f)
203+
AMG.set_state(amg_state)
204+
return
205+
206+
print("Precomputing the state for instance segmentation")
207+
AMG.initialize(raw, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
208+
with open(save_path_amg, "wb") as f:
209+
pickle.dump(AMG.get_state(), f)
210+
211+
195212
def annotator_2d(
196213
raw: np.ndarray,
197214
embedding_path: Optional[str] = None,
@@ -241,6 +258,8 @@ def annotator_2d(
241258
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
242259
wrong_file_callback=show_wrong_file_warning
243260
)
261+
if embedding_path is not None:
262+
_precompute_amg_state(raw, embedding_path)
244263

245264
# we set the pre-computed image embeddings if we don't use tiling
246265
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically)

0 commit comments

Comments
 (0)