@@ -33,8 +33,8 @@ def _download(url, path):
3333 copyfileobj (r_raw , f )
3434
3535
36- def _get_checkpoint (model_type , my_ckpt_path = None ):
37- if my_ckpt_path is None :
36+ def _get_checkpoint (model_type , checkpoint_path = None ):
37+ if checkpoint_path is None :
3838 checkpoint_url = MODEL_URLS [model_type ]
3939 checkpoint_name = checkpoint_url .split ("/" )[- 1 ]
4040 checkpoint_path = os .path .join (CHECKPOINT_FOLDER , checkpoint_name )
@@ -43,13 +43,13 @@ def _get_checkpoint(model_type, my_ckpt_path=None):
4343 if not os .path .exists (checkpoint_path ):
4444 os .makedirs (CHECKPOINT_FOLDER , exist_ok = True )
4545 _download (checkpoint_url , checkpoint_path )
46- else :
47- checkpoint_path = my_ckpt_path
46+ elif not os . path . exists ( checkpoint_path ) :
47+ raise ValueError ( f"The checkpoint path { checkpoint_path } that was passed does not exist." )
4848
4949 return checkpoint_path
5050
5151
52- def get_sam_model (device = None , model_type = "vit_h" , my_ckpt_path = None ):
52+ def get_sam_model (device = None , model_type = "vit_h" , checkpoint_path = None , return_sam = False ):
5353 """Get the SegmentAnything Predictor.
5454
5555 This function will download the required model checkpoint or load it from file if it
@@ -60,13 +60,18 @@ def get_sam_model(device=None, model_type="vit_h", my_ckpt_path=None):
6060 device [str, torch.device] - the device for the model. If none is given will use GPU if available.
6161 (default: None)
6262 model_type [str] - the SegmentAnything model to use. (default: vit_h)
63+ checkpoint_path [str] - the path to the corresponding checkpoint if it is already present
64+ and not in the default model folder. (default: None)
65+ return_sam [bool] - return the sam model object as well as the predictor (default: False)
6366 """
64- checkpoint = _get_checkpoint (model_type , my_ckpt_path )
67+ checkpoint = _get_checkpoint (model_type , checkpoint_path )
6568 device = "cuda" if torch .cuda .is_available () else "cpu"
6669 sam = sam_model_registry [model_type ](checkpoint = checkpoint )
6770 sam .to (device = device )
6871 predictor = SamPredictor (sam )
69- return sam , predictor
72+ if return_sam :
73+ return predictor , sam
74+ return predictor
7075
7176
7277def _compute_2d (input_ , predictor ):
0 commit comments