Skip to content

Commit 00e7a8e

Browse files
Merge pull request #62 from computational-cell-analytics/tunable-instance-seg
Tunable instance seg
2 parents cd4a1ea + e4ef8a8 commit 00e7a8e

13 files changed

+1151
-608
lines changed

micro_sam/instance_segmentation.py

Lines changed: 956 additions & 0 deletions
Large diffs are not rendered by default.

micro_sam/sam_annotator/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
from .annotator_3d import annotator_3d
33
from .annotator_tracking import annotator_tracking
44
from .image_series_annotator import image_folder_annotator, image_series_annotator
5-
from .interactive_instance_segmentation import interactive_instance_segmentation

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from napari import Viewer
88

99
from .. import util
10-
from .. import segment_instances
10+
from .. import instance_segmentation
1111
from ..visualization import project_embeddings_for_visualization
1212
from .util import (
1313
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
@@ -39,6 +39,29 @@ def segment_wigdet(v: Viewer):
3939
v.layers["current_object"].refresh()
4040

4141

42+
def _get_amg(is_tiled, with_background, min_initial_size, use_box, use_mask, use_points, box_extension):
43+
if is_tiled:
44+
amg = instance_segmentation.TiledEmbeddingMaskGenerator(
45+
PREDICTOR, with_background=with_background, min_initial_size=min_initial_size,
46+
use_box=use_box, use_mask=use_mask, use_points=use_points, box_extension=box_extension,
47+
)
48+
else:
49+
amg = instance_segmentation.EmbeddingMaskGenerator(
50+
PREDICTOR, min_initial_size=min_initial_size,
51+
use_box=use_box, use_mask=use_mask, use_points=use_points, box_extension=box_extension,
52+
)
53+
return amg
54+
55+
56+
def _changed_param(amg, **params):
57+
if amg is None:
58+
return None
59+
for name, val in params.items():
60+
if hasattr(amg, name) and getattr(amg, name) != val:
61+
return name
62+
return None
63+
64+
4265
@magicgui(call_button="Automatic Segmentation")
4366
def autosegment_widget(
4467
v: Viewer,
@@ -51,23 +74,27 @@ def autosegment_widget(
5174
use_points: bool = False,
5275
box_extension: float = 0.1,
5376
):
77+
global AMG
5478
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
55-
if is_tiled:
56-
seg = segment_instances.segment_instances_from_embeddings_with_tiling(
57-
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
58-
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
59-
stability_score_thresh=stability_score_thresh,
60-
min_initial_size=min_initial_size,
61-
use_box=use_box, use_points=use_points, use_mask=use_mask,
62-
)
63-
else:
64-
seg = segment_instances.segment_instances_from_embeddings(
65-
PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background,
66-
box_extension=box_extension, pred_iou_thresh=pred_iou_thresh,
67-
stability_score_thresh=stability_score_thresh,
68-
min_initial_size=min_initial_size,
69-
use_box=use_box, use_points=use_points, use_mask=use_mask,
70-
)
79+
param_changed = _changed_param(
80+
AMG, with_background=with_background, min_initial_size=min_initial_size,
81+
use_box=use_box, use_mask=use_mask, use_points=use_points,
82+
box_extension=box_extension,
83+
)
84+
if AMG is None or param_changed:
85+
if param_changed:
86+
print(f"The parameter {param_changed} was changed, so the full instance segmentation has to be recomputed.")
87+
AMG = _get_amg(is_tiled, with_background, min_initial_size, use_box, use_mask, use_points, box_extension)
88+
89+
if not AMG.is_initialized:
90+
AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
91+
92+
seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
93+
if not is_tiled:
94+
shape = v.layers["raw"].data.shape[:2]
95+
seg = instance_segmentation.mask_data_to_segmentation(seg, shape, with_background)
96+
assert isinstance(seg, np.ndarray)
97+
7198
v.layers["auto_segmentation"].data = seg
7299
v.layers["auto_segmentation"].refresh()
73100

@@ -180,7 +207,8 @@ def annotator_2d(
180207
predictor=None,
181208
):
182209
# for access to the predictor and the image embeddings in the widgets
183-
global PREDICTOR, IMAGE_EMBEDDINGS
210+
global PREDICTOR, IMAGE_EMBEDDINGS, AMG
211+
AMG = None
184212

185213
if predictor is None:
186214
PREDICTOR = util.get_sam_model(model_type=model_type)

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from napari.utils import progress
77

88
from .. import util
9-
from ..segment_from_prompts import segment_from_mask
9+
from ..prompt_based_segmentation import segment_from_mask
1010
from ..visualization import project_embeddings_for_visualization
1111
from .util import (
1212
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# from vigra.filters import eccentricityCenters
1212

1313
from .. import util
14-
from ..segment_from_prompts import segment_from_mask
14+
from ..prompt_based_segmentation import segment_from_mask
1515
from .util import (
1616
create_prompt_menu, clear_all_prompts,
1717
prompt_layer_to_boxes, prompt_layer_to_points,

micro_sam/sam_annotator/interactive_instance_segmentation.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

micro_sam/sam_annotator/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from magicgui.widgets import ComboBox, Container
77
from napari import Viewer
88

9-
from ..segment_from_prompts import segment_from_box, segment_from_box_and_points, segment_from_points
9+
from ..prompt_based_segmentation import segment_from_box, segment_from_box_and_points, segment_from_points
1010

1111
# Green and Red
1212
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]

0 commit comments

Comments
 (0)