Skip to content

Commit d09c54b

Browse files
Update utility functionality
1 parent 2badb1a commit d09c54b

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

micro_sam/util.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def _download(url, path):
3333
copyfileobj(r_raw, f)
3434

3535

36-
def _get_checkpoint(model_type, my_ckpt_path=None):
37-
if my_ckpt_path is None:
36+
def _get_checkpoint(model_type, checkpoint_path=None):
37+
if checkpoint_path is None:
3838
checkpoint_url = MODEL_URLS[model_type]
3939
checkpoint_name = checkpoint_url.split("/")[-1]
4040
checkpoint_path = os.path.join(CHECKPOINT_FOLDER, checkpoint_name)
@@ -43,13 +43,13 @@ def _get_checkpoint(model_type, my_ckpt_path=None):
4343
if not os.path.exists(checkpoint_path):
4444
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
4545
_download(checkpoint_url, checkpoint_path)
46-
else:
47-
checkpoint_path = my_ckpt_path
46+
elif not os.path.exists(checkpoint_path):
47+
raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
4848

4949
return checkpoint_path
5050

5151

52-
def get_sam_model(device=None, model_type="vit_h", my_ckpt_path=None):
52+
def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_sam=False):
5353
"""Get the SegmentAnything Predictor.
5454
5555
This function will download the required model checkpoint or load it from file if it
@@ -60,13 +60,18 @@ def get_sam_model(device=None, model_type="vit_h", my_ckpt_path=None):
6060
device [str, torch.device] - the device for the model. If none is given will use GPU if available.
6161
(default: None)
6262
model_type [str] - the SegmentAnything model to use. (default: vit_h)
63+
checkpoint_path [str] - the path to the corresponding checkpoint if it is already present
64+
and not in the default model folder. (default: None)
65+
return_sam [bool] - return the sam model object as well as the predictor (default: False)
6366
"""
64-
checkpoint = _get_checkpoint(model_type, my_ckpt_path)
67+
checkpoint = _get_checkpoint(model_type, checkpoint_path)
6568
device = "cuda" if torch.cuda.is_available() else "cpu"
6669
sam = sam_model_registry[model_type](checkpoint=checkpoint)
6770
sam.to(device=device)
6871
predictor = SamPredictor(sam)
69-
return sam, predictor
72+
if return_sam:
73+
return predictor, sam
74+
return predictor
7075

7176

7277
def _compute_2d(input_, predictor):

0 commit comments

Comments
 (0)