@@ -19,7 +19,7 @@ class SamGeo:
1919 def __init__ (
2020 self ,
2121 model_type = "vit_h" ,
22- checkpoint = "sam_vit_h_4b8939.pth" ,
22+ checkpoint = None ,
2323 automatic = True ,
2424 device = None ,
2525 sam_kwargs = None ,
@@ -31,7 +31,7 @@ def __init__(
3131 Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
3232 checkpoint (str, optional): The path to the model checkpoint. It can be one of the following:
3333 sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
34- Defaults to "sam_vit_h_4b8939.pth" . See https://bit.ly/3VrpxUh for more details.
34+ Defaults to None . See https://bit.ly/3VrpxUh for more details.
3535 automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
3636 The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
3737 device (str, optional): The device to use. It can be one of the following: cpu, cuda.
@@ -58,12 +58,36 @@ def __init__(
5858 CACHE_PATH = os .environ .get (
5959 "TORCH_HOME" , os .path .expanduser ("~/.cache/torch/hub/checkpoints" )
6060 )
61+
62+ model_types = {
63+ "vit_h" : "sam_vit_h_4b8939.pth" ,
64+ "vit_l" : "sam_vit_l_0b3195.pth" ,
65+ "vit_b" : "sam_vit_b_01ec64.pth" ,
66+ }
67+
68+ if model_type not in model_types :
69+ raise ValueError (
70+ f"Model type { model_type } is not supported. It must be one of the following: { model_types } ."
71+ )
72+
73+ checkpoints = {
74+ "sam_vit_h_4b8939.pth" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" ,
75+ "sam_vit_l_0b3195.pth" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" ,
76+ "sam_vit_b_01ec64.pth" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" ,
77+ }
78+
79+ if checkpoint is None :
80+ url = checkpoints [model_types [model_type ]]
81+ checkpoint = model_types [model_type ]
82+ else :
83+ url = None
84+
6185 if not os .path .exists (checkpoint ):
6286 basename = os .path .basename (checkpoint )
6387 checkpoint = os .path .join (CACHE_PATH , basename )
6488 if not os .path .exists (checkpoint ):
6589 print (f"Checkpoint { checkpoint } does not exist." )
66- download_checkpoint (output = checkpoint )
90+ download_checkpoint (url = url , output = checkpoint )
6791 self .checkpoint = checkpoint
6892
6993 # Use cuda if available
0 commit comments