77from napari import Viewer
88
99from .. import util
10- from .. import segment_instances
10+ from .. import instance_segmentation
1111from ..visualization import project_embeddings_for_visualization
1212from .util import (
1313 clear_all_prompts , commit_segmentation_widget , create_prompt_menu ,
@@ -39,6 +39,29 @@ def segment_wigdet(v: Viewer):
3939 v .layers ["current_object" ].refresh ()
4040
4141
42+ def _get_amg (is_tiled , with_background , min_initial_size , use_box , use_mask , use_points , box_extension ):
43+ if is_tiled :
44+ amg = instance_segmentation .TiledEmbeddingMaskGenerator (
45+ PREDICTOR , with_background = with_background , min_initial_size = min_initial_size ,
46+ use_box = use_box , use_mask = use_mask , use_points = use_points , box_extension = box_extension ,
47+ )
48+ else :
49+ amg = instance_segmentation .EmbeddingMaskGenerator (
50+ PREDICTOR , min_initial_size = min_initial_size ,
51+ use_box = use_box , use_mask = use_mask , use_points = use_points , box_extension = box_extension ,
52+ )
53+ return amg
54+
55+
56+ def _changed_param (amg , ** params ):
57+ if amg is None :
58+ return None
59+ for name , val in params .items ():
60+ if hasattr (amg , name ) and getattr (amg , name ) != val :
61+ return name
62+ return None
63+
64+
4265@magicgui (call_button = "Automatic Segmentation" )
4366def autosegment_widget (
4467 v : Viewer ,
@@ -51,23 +74,27 @@ def autosegment_widget(
5174 use_points : bool = False ,
5275 box_extension : float = 0.1 ,
5376):
77+ global AMG
5478 is_tiled = IMAGE_EMBEDDINGS ["input_size" ] is None
55- if is_tiled :
56- seg = segment_instances .segment_instances_from_embeddings_with_tiling (
57- PREDICTOR , IMAGE_EMBEDDINGS , with_background = with_background ,
58- box_extension = box_extension , pred_iou_thresh = pred_iou_thresh ,
59- stability_score_thresh = stability_score_thresh ,
60- min_initial_size = min_initial_size ,
61- use_box = use_box , use_points = use_points , use_mask = use_mask ,
62- )
63- else :
64- seg = segment_instances .segment_instances_from_embeddings (
65- PREDICTOR , IMAGE_EMBEDDINGS , with_background = with_background ,
66- box_extension = box_extension , pred_iou_thresh = pred_iou_thresh ,
67- stability_score_thresh = stability_score_thresh ,
68- min_initial_size = min_initial_size ,
69- use_box = use_box , use_points = use_points , use_mask = use_mask ,
70- )
79+ param_changed = _changed_param (
80+ AMG , with_background = with_background , min_initial_size = min_initial_size ,
81+ use_box = use_box , use_mask = use_mask , use_points = use_points ,
82+ box_extension = box_extension ,
83+ )
84+ if AMG is None or param_changed :
85+ if param_changed :
86+ print (f"The parameter { param_changed } was changed, so the full instance segmentation has to be recomputed." )
87+ AMG = _get_amg (is_tiled , with_background , min_initial_size , use_box , use_mask , use_points , box_extension )
88+
89+ if not AMG .is_initialized :
90+ AMG .initialize (v .layers ["raw" ].data , image_embeddings = IMAGE_EMBEDDINGS , verbose = True )
91+
92+ seg = AMG .generate (pred_iou_thresh = pred_iou_thresh , stability_score_thresh = stability_score_thresh )
93+ if not is_tiled :
94+ shape = v .layers ["raw" ].data .shape [:2 ]
95+ seg = instance_segmentation .mask_data_to_segmentation (seg , shape , with_background )
96+ assert isinstance (seg , np .ndarray )
97+
7198 v .layers ["auto_segmentation" ].data = seg
7299 v .layers ["auto_segmentation" ].refresh ()
73100
@@ -180,7 +207,8 @@ def annotator_2d(
180207 predictor = None ,
181208):
182209 # for access to the predictor and the image embeddings in the widgets
183- global PREDICTOR , IMAGE_EMBEDDINGS
210+ global PREDICTOR , IMAGE_EMBEDDINGS , AMG
211+ AMG = None
184212
185213 if predictor is None :
186214 PREDICTOR = util .get_sam_model (model_type = model_type )
0 commit comments