Skip to content

Commit eda1e3d

Browse files
Enable setting the model type
1 parent 896da61 commit eda1e3d

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def autosegment_widget(v: Viewer, method: str = "default"):
3636
v.layers["auto_segmentation"].refresh()
3737

3838

39-
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
39+
def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None, model_type="vit_h"):
4040
# for access to the predictor and the image embeddings in the widgets
4141
global PREDICTOR, IMAGE_EMBEDDINGS, SAM
4242

43-
PREDICTOR, SAM = util.get_sam_model(return_sam=True)
43+
PREDICTOR, SAM = util.get_sam_model(model_type=model_type, return_sam=True)
4444
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
4545
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
4646

@@ -166,6 +166,9 @@ def main():
166166
"--show_embeddings", action="store_true",
167167
help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
168168
)
169+
parser.add_argument(
170+
"--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
171+
)
169172

170173
args = parser.parse_args()
171174
raw = util.load_image_data(args.input, ndim=2, key=args.key)
@@ -180,5 +183,6 @@ def main():
180183

181184
annotator_2d(
182185
raw, embedding_path=args.embedding_path,
183-
show_embeddings=args.show_embeddings, segmentation_result=segmentation_result
186+
show_embeddings=args.show_embeddings, segmentation_result=segmentation_result,
187+
model_type=args.model_type,
184188
)

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, projection: str
162162
v.layers["current_object"].refresh()
163163

164164

165-
def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None):
165+
def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_result=None, model_type="vit_h"):
166166
# for access to the predictor and the image embeddings in the widgets
167167
global PREDICTOR, IMAGE_EMBEDDINGS, DEFAULT_PROJECTION
168-
PREDICTOR = util.get_sam_model()
168+
PREDICTOR = util.get_sam_model(model_type=model_type)
169169
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
170170

171171
# the mask projection currently only works for square images
@@ -291,6 +291,9 @@ def main():
291291
"--show_embeddings", action="store_true",
292292
help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
293293
)
294+
parser.add_argument(
295+
"--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
296+
)
294297

295298
args = parser.parse_args()
296299
raw = util.load_image_data(args.input, ndim=3, key=args.key)
@@ -305,5 +308,6 @@ def main():
305308

306309
annotator_3d(
307310
raw, embedding_path=args.embedding_path,
308-
show_embeddings=args.show_embeddings, segmentation_result=segmentation_result
311+
show_embeddings=args.show_embeddings, segmentation_result=segmentation_result,
312+
model_type=args.model_type,
309313
)

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,12 @@ def commit_tracking_widget(v: Viewer, layer: str = "current_track"):
299299
v.layers["prompts"].refresh()
300300

301301

302-
def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None):
302+
def annotator_tracking(raw, embedding_path=None, show_embeddings=False, tracking_result=None, model_type="vit_h"):
303303
# global state
304304
global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE
305305
global TRACKING_WIDGET
306306

307-
PREDICTOR = util.get_sam_model()
307+
PREDICTOR = util.get_sam_model(model_type=model_type)
308308
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
309309

310310
CURRENT_TRACK_ID = 1
@@ -445,6 +445,9 @@ def main():
445445
"--show_embeddings", action="store_true",
446446
help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
447447
)
448+
parser.add_argument(
449+
"--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
450+
)
448451

449452
args = parser.parse_args()
450453
raw = util.load_image_data(args.input, ndim=3, key=args.key)
@@ -458,5 +461,6 @@ def main():
458461
warnings.warn("You have not passed an embedding_path. Restarting the annotator may take a long time.")
459462

460463
annotator_tracking(
461-
raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings, tracking_result=tracking_result
464+
raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings,
465+
tracking_result=tracking_result, model_type=args.model_type,
462466
)

0 commit comments

Comments
 (0)