Skip to content

Commit bada8fa

Browse files
Enable passing custom predictor to all annotator functions
1 parent 449e025 commit bada8fa

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from magicgui import magicgui
88
from napari import Viewer
99
from napari.utils import progress
10+
from segment_anything import SamPredictor
1011

1112
from .. import util
1213
from ..prompt_based_segmentation import segment_from_mask
@@ -195,10 +196,15 @@ def annotator_3d(
195196
tile_shape: Optional[Tuple[int, int]] = None,
196197
halo: Optional[Tuple[int, int]] = None,
197198
return_viewer: bool = False,
199+
predictor: Optional[SamPredictor] = None,
198200
) -> None:
199201
# for access to the predictor and the image embeddings in the widgets
200202
global PREDICTOR, IMAGE_EMBEDDINGS
201-
PREDICTOR = util.get_sam_model(model_type=model_type)
203+
204+
if predictor is None:
205+
PREDICTOR = util.get_sam_model(model_type=model_type)
206+
else:
207+
PREDICTOR = predictor
202208
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
203209
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
204210
wrong_file_callback=show_wrong_file_warning,

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from napari import Viewer
1010
from napari.utils import progress
1111
from scipy.ndimage import shift
12+
from segment_anything import SamPredictor
1213

1314
# this is more precise for comuting the centers, but slow!
1415
# from vigra.filters import eccentricityCenters
@@ -355,12 +356,16 @@ def annotator_tracking(
355356
tile_shape: Optional[Tuple[int, int]] = None,
356357
halo: Optional[Tuple[int, int]] = None,
357358
return_viewer: bool = False,
359+
predictor: Optional[SamPredictor] = None,
358360
) -> None:
359361
# global state
360362
global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE
361363
global TRACKING_WIDGET
362364

363-
PREDICTOR = util.get_sam_model(model_type=model_type)
365+
if predictor is None:
366+
PREDICTOR = util.get_sam_model(model_type=model_type)
367+
else:
368+
PREDICTOR = predictor
364369
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
365370
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
366371
wrong_file_callback=show_wrong_file_warning,

micro_sam/sam_annotator/image_series_annotator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from magicgui import magicgui
1010
from napari.utils import progress as tqdm
11+
from segment_anything import SamPredictor
12+
1113
from .annotator_2d import annotator_2d
1214
from .. import util
1315

@@ -32,6 +34,7 @@ def image_series_annotator(
3234
image_files: List[str],
3335
output_folder: str,
3436
embedding_path: Optional[str] = None,
37+
predictor: Optional[SamPredictor] = None,
3538
**kwargs
3639
) -> None:
3740
"""
@@ -45,7 +48,8 @@ def image_series_annotator(
4548
os.makedirs(output_folder, exist_ok=True)
4649
next_image_id = 0
4750

48-
predictor = util.get_sam_model(model_type=kwargs.get("model_type", "vit_h"))
51+
if predictor is None:
52+
predictor = util.get_sam_model(model_type=kwargs.get("model_type", "vit_h"))
4953
if embedding_path is None:
5054
embedding_paths = None
5155
else:
@@ -101,12 +105,13 @@ def image_folder_annotator(
101105
output_folder: str,
102106
pattern: str = "*",
103107
embedding_path: Optional[str] = None,
108+
predictor: Optional[SamPredictor] = None,
104109
**kwargs
105110
) -> None:
106111
"""
107112
"""
108113
image_files = sorted(glob(os.path.join(input_folder, pattern)))
109-
image_series_annotator(image_files, output_folder, embedding_path, **kwargs)
114+
image_series_annotator(image_files, output_folder, embedding_path, predictor, **kwargs)
110115

111116

112117
def main():

0 commit comments

Comments
 (0)