Skip to content

Commit 681cb96

Browse files
authored
Fixed model download bug (#136)
1 parent 2f4a5ef commit 681cb96

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

docs/examples/text_prompts_batch.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,16 @@
198198
"outputs": [],
199199
"source": [
200200
"sam.predict_batch(\n",
201-
" images='tiles', \n",
202-
" out_dir='masks', \n",
203-
" text_prompt=text_prompt, \n",
204-
" box_threshold=0.24, \n",
201+
" images='tiles',\n",
202+
" out_dir='masks',\n",
203+
" text_prompt=text_prompt,\n",
204+
" box_threshold=0.24,\n",
205205
" text_threshold=0.24,\n",
206-
" mask_multiplier=255, \n",
206+
" mask_multiplier=255,\n",
207207
" dtype='uint8',\n",
208208
" merge=True,\n",
209-
" verbose=True\n",
210-
" )"
209+
" verbose=True,\n",
210+
")"
211211
]
212212
},
213213
{

samgeo/samgeo.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SamGeo:
1919
def __init__(
2020
self,
2121
model_type="vit_h",
22-
checkpoint="sam_vit_h_4b8939.pth",
22+
checkpoint=None,
2323
automatic=True,
2424
device=None,
2525
sam_kwargs=None,
@@ -31,7 +31,7 @@ def __init__(
3131
Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
3232
checkpoint (str, optional): The path to the model checkpoint. It can be one of the following:
3333
sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
34-
Defaults to "sam_vit_h_4b8939.pth". See https://bit.ly/3VrpxUh for more details.
34+
Defaults to None. See https://bit.ly/3VrpxUh for more details.
3535
automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
3636
The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
3737
device (str, optional): The device to use. It can be one of the following: cpu, cuda.
@@ -58,12 +58,36 @@ def __init__(
5858
CACHE_PATH = os.environ.get(
5959
"TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
6060
)
61+
62+
model_types = {
63+
"vit_h": "sam_vit_h_4b8939.pth",
64+
"vit_l": "sam_vit_l_0b3195.pth",
65+
"vit_b": "sam_vit_b_01ec64.pth",
66+
}
67+
68+
if model_type not in model_types:
69+
raise ValueError(
70+
f"Model type {model_type} is not supported. It must be one of the following: {model_types}."
71+
)
72+
73+
checkpoints = {
74+
"sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
75+
"sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
76+
"sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
77+
}
78+
79+
if checkpoint is None:
80+
url = checkpoints[model_types[model_type]]
81+
checkpoint = model_types[model_type]
82+
else:
83+
url = None
84+
6185
if not os.path.exists(checkpoint):
6286
basename = os.path.basename(checkpoint)
6387
checkpoint = os.path.join(CACHE_PATH, basename)
6488
if not os.path.exists(checkpoint):
6589
print(f"Checkpoint {checkpoint} does not exist.")
66-
download_checkpoint(output=checkpoint)
90+
download_checkpoint(url=url, output=checkpoint)
6791
self.checkpoint = checkpoint
6892

6993
# Use cuda if available

samgeo/text_sam.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,9 @@ def predict_batch(
374374
for i, image in enumerate(images):
375375
basename = os.path.splitext(os.path.basename(image))[0]
376376
if verbose:
377-
print(f"Processing image {str(i+1).zfill(len(str(len(images))))} of {len(images)}: {image}...")
377+
print(
378+
f"Processing image {str(i+1).zfill(len(str(len(images))))} of {len(images)}: {image}..."
379+
)
378380
output = os.path.join(out_dir, f"{basename}_mask.tif")
379381
self.predict(
380382
image,

0 commit comments

Comments
 (0)