@@ -109,7 +109,7 @@ class LangSAM:
109109 A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.
110110 """
111111
112- def __init__ (self , model_type = "vit_h" ):
112+ def __init__ (self , model_type = "vit_h" , checkpoint = None ):
113113 """Initialize the LangSAM instance.
114114
115115 Args:
@@ -119,7 +119,7 @@ def __init__(self, model_type="vit_h"):
119119
120120 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
121121 self .build_groundingdino ()
122- self .build_sam (model_type )
122+ self .build_sam (model_type , checkpoint )
123123
124124 self .source = None
125125 self .image = None
@@ -129,17 +129,21 @@ def __init__(self, model_type="vit_h"):
129129 self .logits = None
130130 self .prediction = None
131131
132- def build_sam (self , model_type ):
132+ def build_sam (self , model_type , checkpoint_url = None ):
133133 """Build the SAM model.
134134
135135 Args:
136136 model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
137137 Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
138+ checkpoint_url:
138139 """
139- checkpoint_url = SAM_MODELS [model_type ]
140- sam = sam_model_registry [model_type ]()
141- state_dict = torch .hub .load_state_dict_from_url (checkpoint_url )
142- sam .load_state_dict (state_dict , strict = True )
140+ if checkpoint_url is not None :
141+ sam = sam_model_registry [model_type ](checkpoint = checkpoint_url )
142+ else :
143+ checkpoint_url = SAM_MODELS [model_type ]
144+ sam = sam_model_registry [model_type ]()
145+ state_dict = torch .hub .load_state_dict_from_url (checkpoint_url )
146+ sam .load_state_dict (state_dict , strict = True )
143147 sam .to (device = self .device )
144148 self .sam = SamPredictor (sam )
145149
0 commit comments