Skip to content

Commit 896da61

Browse files
Validate checksums when downloading
1 parent 63bc2f8 commit 896da61

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

micro_sam/util.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import os
23
from shutil import copyfileobj
34

@@ -28,9 +29,14 @@
2829
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
2930
}
3031
CHECKPOINT_FOLDER = os.environ.get("SAM_MODELS", os.path.expanduser("~/.sam_models"))
32+
CHECKSUMS = {
33+
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
34+
"vit_l": None, # TODO download vit_l and compute the checksum
35+
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"
36+
}
3137

3238

33-
def _download(url, path):
39+
def _download(url, path, model_type):
3440
with requests.get(url, stream=True, verify=True) as r:
3541
if r.status_code != 200:
3642
r.raise_for_status()
@@ -42,6 +48,20 @@ def _download(url, path):
4248
with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
4349
copyfileobj(r_raw, f)
4450

51+
# validate the checksum
52+
expected_checksum = CHECKSUMS[model_type]
53+
if expected_checksum is None:
54+
return
55+
with open(path, "rb") as f:
56+
file_ = f.read()
57+
checksum = hashlib.sha256(file_).hexdigest()
58+
if checksum != expected_checksum:
59+
raise RuntimeError(
60+
"The checksum of the download does not match the expected checksum."
61+
f"Expected: {expected_checksum}, got: {checksum}"
62+
)
63+
print("Download successful and checksums agree.")
64+
4565

4666
def _get_checkpoint(model_type, checkpoint_path=None):
4767
if checkpoint_path is None:
@@ -52,7 +72,7 @@ def _get_checkpoint(model_type, checkpoint_path=None):
5272
# download the checkpoint if necessary
5373
if not os.path.exists(checkpoint_path):
5474
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
55-
_download(checkpoint_url, checkpoint_path)
75+
_download(checkpoint_url, checkpoint_path, model_type)
5676
elif not os.path.exists(checkpoint_path):
5777
raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
5878

0 commit comments

Comments
 (0)