33
44from magicgui import magicgui
55from napari import Viewer
6+ from napari .utils import progress
67
78from .. import util
89from ..segment_from_prompts import segment_from_mask , segment_from_points
2122# TODO refactor
2223def _segment_volume (
2324 seg , predictor , image_embeddings , segmented_slices ,
24- stop_lower , stop_upper , iou_threshold , method
25+ stop_lower , stop_upper , iou_threshold , method ,
26+ progress_bar = None ,
2527):
2628 assert method in ("mask" , "bounding_box" )
2729 if method == "mask" :
2830 use_mask , use_box = True , True
2931 else :
3032 use_mask , use_box = False , True
3133
34+ def _update_progress ():
35+ if progress_bar is not None :
36+ progress_bar .update (1 )
37+
3238 # TODO refactor to utils so that it can be used by other plugins
3339 def segment_range (z_start , z_stop , increment , stopping_criterion , threshold = None , verbose = False ):
3440 z = z_start + increment
@@ -50,6 +56,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
5056 if verbose :
5157 print (f"Segment { z_start } to { z_stop } : stop at slice { z } " )
5258 break
59+ _update_progress ()
5360
5461 z0 , z1 = int (segmented_slices .min ()), int (segmented_slices .max ())
5562
@@ -75,6 +82,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
7582 seg_prompt = np .logical_or (seg [z_start ] == 1 , seg [z_stop ] == 1 )
7683 seg [z ] = segment_from_mask (predictor , seg_prompt , image_embeddings = image_embeddings , i = z ,
7784 use_mask = use_mask , use_box = use_box )
85+ _update_progress ()
7886
7987 else : # there is a range of more than 2 slices in between -> segment ranges
8088 # segment from bottom
@@ -89,6 +97,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
8997 seg_prompt = np .logical_or (seg [z_mid - 1 ] == 1 , seg [z_mid + 1 ] == 1 )
9098 seg [z_mid ] = segment_from_mask (predictor , seg_prompt , image_embeddings = image_embeddings , i = z_mid ,
9199 use_mask = use_mask , use_box = use_box )
100+ _update_progress ()
92101
93102 return seg
94103
@@ -118,16 +127,20 @@ def segment_slice_wigdet(v: Viewer):
118127def segment_volume_widget (v : Viewer , iou_threshold : float = 0.8 , method : str = "mask" ):
119128 # step 1: segment all slices with prompts
120129 shape = v .layers ["raw" ].data .shape
121- seg , slices , stop_lower , stop_upper = segment_slices_with_prompts (
122- PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape
123- )
124130
125- # step 2: segment the rest of the volume based on smart prompting
126- seg = _segment_volume (
127- seg , PREDICTOR , IMAGE_EMBEDDINGS , slices ,
128- stop_lower , stop_upper ,
129- iou_threshold = iou_threshold , method = method ,
130- )
131+ with progress (total = shape [0 ]) as progress_bar :
132+
133+ seg , slices , stop_lower , stop_upper = segment_slices_with_prompts (
134+ PREDICTOR , v .layers ["prompts" ], IMAGE_EMBEDDINGS , shape , progress_bar = progress_bar ,
135+ )
136+
137+ # step 2: segment the rest of the volume based on smart prompting
138+ seg = _segment_volume (
139+ seg , PREDICTOR , IMAGE_EMBEDDINGS , slices ,
140+ stop_lower , stop_upper ,
141+ iou_threshold = iou_threshold , method = method ,
142+ progress_bar = progress_bar ,
143+ )
131144
132145 v .layers ["current_object" ].data = seg
133146 v .layers ["current_object" ].refresh ()
0 commit comments