Skip to content

Commit 5c53672

Browse files
Implement progress bar
1 parent 659f225 commit 5c53672

File tree

4 files changed

+54
-19
lines changed

4 files changed

+54
-19
lines changed

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from magicgui import magicgui
55
from napari import Viewer
6+
from napari.utils import progress
67

78
from .. import util
89
from ..segment_from_prompts import segment_from_mask, segment_from_points
@@ -21,14 +22,19 @@
2122
# TODO refactor
2223
def _segment_volume(
2324
seg, predictor, image_embeddings, segmented_slices,
24-
stop_lower, stop_upper, iou_threshold, method
25+
stop_lower, stop_upper, iou_threshold, method,
26+
progress_bar=None,
2527
):
2628
assert method in ("mask", "bounding_box")
2729
if method == "mask":
2830
use_mask, use_box = True, True
2931
else:
3032
use_mask, use_box = False, True
3133

34+
def _update_progress():
35+
if progress_bar is not None:
36+
progress_bar.update(1)
37+
3238
# TODO refactor to utils so that it can be used by other plugins
3339
def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None, verbose=False):
3440
z = z_start + increment
@@ -50,6 +56,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
5056
if verbose:
5157
print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
5258
break
59+
_update_progress()
5360

5461
z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
5562

@@ -75,6 +82,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
7582
seg_prompt = np.logical_or(seg[z_start] == 1, seg[z_stop] == 1)
7683
seg[z] = segment_from_mask(predictor, seg_prompt, image_embeddings=image_embeddings, i=z,
7784
use_mask=use_mask, use_box=use_box)
85+
_update_progress()
7886

7987
else: # there is a range of more than 2 slices in between -> segment ranges
8088
# segment from bottom
@@ -89,6 +97,7 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
8997
seg_prompt = np.logical_or(seg[z_mid - 1] == 1, seg[z_mid + 1] == 1)
9098
seg[z_mid] = segment_from_mask(predictor, seg_prompt, image_embeddings=image_embeddings, i=z_mid,
9199
use_mask=use_mask, use_box=use_box)
100+
_update_progress()
92101

93102
return seg
94103

@@ -118,16 +127,20 @@ def segment_slice_wigdet(v: Viewer):
118127
def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mask"):
119128
# step 1: segment all slices with prompts
120129
shape = v.layers["raw"].data.shape
121-
seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
122-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape
123-
)
124130

125-
# step 2: segment the rest of the volume based on smart prompting
126-
seg = _segment_volume(
127-
seg, PREDICTOR, IMAGE_EMBEDDINGS, slices,
128-
stop_lower, stop_upper,
129-
iou_threshold=iou_threshold, method=method,
130-
)
131+
with progress(total=shape[0]) as progress_bar:
132+
133+
seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
134+
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar,
135+
)
136+
137+
# step 2: segment the rest of the volume based on smart prompting
138+
seg = _segment_volume(
139+
seg, PREDICTOR, IMAGE_EMBEDDINGS, slices,
140+
stop_lower, stop_upper,
141+
iou_threshold=iou_threshold, method=method,
142+
progress_bar=progress_bar,
143+
)
131144

132145
v.layers["current_object"].data = seg
133146
v.layers["current_object"].refresh()

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from magicgui import magicgui
55
from napari import Viewer
6+
from napari.utils import progress
67

78
from .. import util
89
from ..segment_from_prompts import segment_from_mask, segment_from_points
@@ -19,13 +20,17 @@
1920

2021
# TODO motion model!!!
2122
# TODO handle division annotations + division classifier
22-
def _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold, method):
23+
def _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold, method, progress_bar=None):
2324
assert method in ("mask", "bounding_box")
2425
if method == "mask":
2526
use_mask, use_box = True, True
2627
else:
2728
use_mask, use_box = False, True
2829

30+
def _update_progress():
31+
if progress_bar is not None:
32+
progress_bar.update(1)
33+
2934
t0 = int(slices.min())
3035
t = t0 + 1
3136
while True:
@@ -36,6 +41,7 @@ def _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, th
3641
seg_prev = seg[t - 1]
3742
seg_t = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=t,
3843
use_mask=use_mask, use_box=use_box)
44+
_update_progress()
3945

4046
if (threshold is not None) and (seg_prev is not None):
4147
iou = util.compute_iou(seg_prev, seg_t)
@@ -79,14 +85,18 @@ def segment_frame_wigdet(v: Viewer):
7985

8086
@magicgui(call_button="Track Object [V]", method={"choices": ["bounding_box", "mask"]})
8187
def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mask"):
82-
# step 1: segment all slices with prompts
8388
shape = v.layers["raw"].data.shape
84-
seg, slices, _, stop_upper = segment_slices_with_prompts(
85-
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape
86-
)
8789

88-
# step 2: track the object starting from the lowest annotated slice
89-
seg = _track_from_prompts(seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, iou_threshold, method)
90+
with progress(total=shape[0]) as progress_bar:
91+
# step 1: segment all slices with prompts
92+
seg, slices, _, stop_upper = segment_slices_with_prompts(
93+
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape, progress_bar=progress_bar
94+
)
95+
96+
# step 2: track the object starting from the lowest annotated slice
97+
seg = _track_from_prompts(
98+
seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, iou_threshold, method, progress_bar=progress_bar
99+
)
90100

91101
v.layers["current_track"].data = seg
92102
v.layers["current_track"].refresh()

micro_sam/sam_annotator/util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,37 @@ def prompt_layer_to_points(prompt_layer, i=None):
7676
return this_points, this_labels
7777

7878

79-
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape):
79+
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape, progress_bar=None):
8080
seg = np.zeros(shape, dtype="uint32")
8181

8282
slices = np.unique(prompt_layer.data[:, 0]).astype("int")
8383
stop_lower, stop_upper = False, False
8484

85+
def _update_progress():
86+
if progress_bar is not None:
87+
progress_bar.update(1)
88+
8589
for i in slices:
8690
prompts_i = prompt_layer_to_points(prompt_layer, i)
8791

8892
# TODO also take into account division properties once we have this implemented in tracking
8993
# do we end the segmentation at the outer slices?
9094
if prompts_i is None:
95+
9196
if i == slices[0]:
9297
stop_lower = True
9398
elif i == slices[-1]:
9499
stop_upper = True
95100
else:
96101
raise RuntimeError("Stop slices can only be at the start or end")
102+
97103
seg[i] = 0
104+
_update_progress()
98105
continue
99106

100107
points, labels = prompts_i
101108
seg_i = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings, i=i)
102109
seg[i] = seg_i
110+
_update_progress()
103111

104112
return seg, slices, stop_lower, stop_upper

micro_sam/util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from skimage.measure import regionprops
1111

1212
from segment_anything import sam_model_registry, SamPredictor
13-
from tqdm import tqdm
13+
14+
try:
15+
from napari.utils import progress as tqdm
16+
except ImportError:
17+
from tqdm import tqdm
1418

1519
MODEL_URLS = {
1620
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",

0 commit comments

Comments
 (0)