|
25 | 25 | _MODEL_URLS = { |
26 | 26 | "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", |
27 | 27 | "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", |
28 | | - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" |
| 28 | + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", |
| 29 | + # preliminary finetuned models |
| 30 | + "vit_h_lm": "https://owncloud.gwdg.de/index.php/s/CnxBvsdGPN0TD3A/download", |
| 31 | + "vit_b_lm": "https://owncloud.gwdg.de/index.php/s/gGlR1LFsav0eQ2k/download", |
29 | 32 | } |
30 | 33 | _CHECKPOINT_FOLDER = os.environ.get("SAM_MODELS", os.path.expanduser("~/.sam_models")) |
31 | 34 | _CHECKSUMS = { |
32 | 35 | "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", |
33 | 36 | "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", |
34 | | - "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912" |
| 37 | + "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", |
| 38 | + # preliminary finetuned models |
| 39 | + "vit_h_lm": "c30a580e6ccaff2f4f0fbaf9cad10cee615a915cdd8c7bc4cb50ea9bdba3fc09", |
| 40 | + "vit_b_lm": "f2b8676f92a123f6f8ac998818118bd7269a559381ec60af4ac4be5c86024a1b", |
35 | 41 | } |
36 | 42 |
|
37 | 43 |
|
@@ -105,7 +111,13 @@ def get_sam_model( |
105 | 111 | """ |
106 | 112 | checkpoint = _get_checkpoint(model_type, checkpoint_path) |
107 | 113 | device = "cuda" if torch.cuda.is_available() else "cpu" |
108 | | - sam = sam_model_registry[model_type](checkpoint=checkpoint) |
| 114 | + |
| 115 | + # Our custom model types have a suffix "_...". This suffix needs to be stripped |
| 116 | + # before calling sam_model_registry. |
| 117 | + model_type_ = model_type[:5] |
| 118 | + assert model_type_ in ("vit_h", "vit_b", "vit_l") |
| 119 | + |
| 120 | + sam = sam_model_registry[model_type_](checkpoint=checkpoint) |
109 | 121 | sam.to(device=device) |
110 | 122 | predictor = SamPredictor(sam) |
111 | 123 | if return_sam: |
|
0 commit comments