1+ import hashlib
12import os
23from shutil import copyfileobj
34
2829 "vit_b" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
2930}
3031CHECKPOINT_FOLDER = os .environ .get ("SAM_MODELS" , os .path .expanduser ("~/.sam_models" ))
32+ CHECKSUMS = {
33+ "vit_h" : "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e" ,
34+ "vit_l" : None , # TODO download vit_l and compute the checksum
35+ "vit_b" : "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"
36+ }
3137
3238
33- def _download (url , path ):
39+ def _download (url , path , model_type ):
3440 with requests .get (url , stream = True , verify = True ) as r :
3541 if r .status_code != 200 :
3642 r .raise_for_status ()
@@ -42,6 +48,20 @@ def _download(url, path):
4248 with tqdm .wrapattr (r .raw , "read" , total = file_size , desc = desc ) as r_raw , open (path , "wb" ) as f :
4349 copyfileobj (r_raw , f )
4450
51+ # validate the checksum
52+ expected_checksum = CHECKSUMS [model_type ]
53+ if expected_checksum is None :
54+ return
55+ with open (path , "rb" ) as f :
56+ file_ = f .read ()
57+ checksum = hashlib .sha256 (file_ ).hexdigest ()
58+ if checksum != expected_checksum :
59+ raise RuntimeError (
60+ "The checksum of the download does not match the expected checksum."
61+ f"Expected: { expected_checksum } , got: { checksum } "
62+ )
63+ print ("Download successful and checksums agree." )
64+
4565
4666def _get_checkpoint (model_type , checkpoint_path = None ):
4767 if checkpoint_path is None :
@@ -52,7 +72,7 @@ def _get_checkpoint(model_type, checkpoint_path=None):
5272 # download the checkpoint if necessary
5373 if not os .path .exists (checkpoint_path ):
5474 os .makedirs (CHECKPOINT_FOLDER , exist_ok = True )
55- _download (checkpoint_url , checkpoint_path )
75+ _download (checkpoint_url , checkpoint_path , model_type )
5676 elif not os .path .exists (checkpoint_path ):
5777 raise ValueError (f"The checkpoint path { checkpoint_path } that was passed does not exist." )
5878
0 commit comments