Skip to content

Commit 22d26eb

Browse files
authored
Add checkpoint to textsam.LangSAM() (#204)
1 parent 0029a6d commit 22d26eb

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

samgeo/text_sam.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class LangSAM:
109109
A Language-based Segment-Anything Model (LangSAM) class which combines GroundingDINO and SAM.
110110
"""
111111

112-
def __init__(self, model_type="vit_h"):
112+
def __init__(self, model_type="vit_h", checkpoint=None):
113113
"""Initialize the LangSAM instance.
114114
115115
Args:
@@ -119,7 +119,7 @@ def __init__(self, model_type="vit_h"):
119119

120120
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121121
self.build_groundingdino()
122-
self.build_sam(model_type)
122+
self.build_sam(model_type, checkpoint)
123123

124124
self.source = None
125125
self.image = None
@@ -129,17 +129,21 @@ def __init__(self, model_type="vit_h"):
129129
self.logits = None
130130
self.prediction = None
131131

132-
def build_sam(self, model_type):
132+
def build_sam(self, model_type, checkpoint_url=None):
133133
"""Build the SAM model.
134134
135135
Args:
136136
model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
137137
Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
138+
checkpoint_url:
138139
"""
139-
checkpoint_url = SAM_MODELS[model_type]
140-
sam = sam_model_registry[model_type]()
141-
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
142-
sam.load_state_dict(state_dict, strict=True)
140+
if checkpoint_url is not None:
141+
sam = sam_model_registry[model_type](checkpoint=checkpoint_url)
142+
else:
143+
checkpoint_url = SAM_MODELS[model_type]
144+
sam = sam_model_registry[model_type]()
145+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
146+
sam.load_state_dict(state_dict, strict=True)
143147
sam.to(device=self.device)
144148
self.sam = SamPredictor(sam)
145149

0 commit comments

Comments
 (0)