Skip to content

Commit 55027d1

Browse files
Add the two preliminary fine-tuned models
1 parent bada8fa commit 55027d1

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

micro_sam/util.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@
2525
_MODEL_URLS = {
2626
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
2727
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
28-
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
28+
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
29+
# preliminary finetuned models
30+
"vit_h_lm": "https://owncloud.gwdg.de/index.php/s/CnxBvsdGPN0TD3A/download",
31+
"vit_b_lm": "https://owncloud.gwdg.de/index.php/s/gGlR1LFsav0eQ2k/download",
2932
}
3033
_CHECKPOINT_FOLDER = os.environ.get("SAM_MODELS", os.path.expanduser("~/.sam_models"))
3134
_CHECKSUMS = {
3235
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
3336
"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
34-
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"
37+
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
38+
# preliminary finetuned models
39+
"vit_h_lm": "c30a580e6ccaff2f4f0fbaf9cad10cee615a915cdd8c7bc4cb50ea9bdba3fc09",
40+
"vit_b_lm": "f2b8676f92a123f6f8ac998818118bd7269a559381ec60af4ac4be5c86024a1b",
3541
}
3642

3743

@@ -105,7 +111,13 @@ def get_sam_model(
105111
"""
106112
checkpoint = _get_checkpoint(model_type, checkpoint_path)
107113
device = "cuda" if torch.cuda.is_available() else "cpu"
108-
sam = sam_model_registry[model_type](checkpoint=checkpoint)
114+
115+
# Our custom model types have a suffix "_...". This suffix needs to be stripped
116+
# before calling sam_model_registry.
117+
model_type_ = model_type[:5]
118+
assert model_type_ in ("vit_h", "vit_b", "vit_l")
119+
120+
sam = sam_model_registry[model_type_](checkpoint=checkpoint)
109121
sam.to(device=device)
110122
predictor = SamPredictor(sam)
111123
if return_sam:

0 commit comments

Comments
 (0)