@@ -27,12 +27,12 @@ def read_file(blobpath: str) -> bytes:
27
27
return resp .content
28
28
29
29
30
- def check_hash (data : bytes , hash : str ) -> bool :
31
- data_hash = hashlib .sha256 (data ).hexdigest ()
32
- return data_hash == hash
30
+ def check_hash (data : bytes , expected_hash : str ) -> bool :
31
+ actual_hash = hashlib .sha256 (data ).hexdigest ()
32
+ return actual_hash == expected_hash
33
33
34
34
35
- def read_file_cached (blobpath : str , expected_hash : Optional [str ]= None ) -> bytes :
35
+ def read_file_cached (blobpath : str , expected_hash : Optional [str ] = None ) -> bytes :
36
36
user_specified_cache = True
37
37
if "TIKTOKEN_CACHE_DIR" in os .environ :
38
38
cache_dir = os .environ ["TIKTOKEN_CACHE_DIR" ]
@@ -52,13 +52,15 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
52
52
if os .path .exists (cache_path ):
53
53
with open (cache_path , "rb" ) as f :
54
54
data = f .read ()
55
- if expected_hash and not check_hash (data , expected_hash ):
56
- raise ValueError (
57
- f"Hash mismatch for cached data from { blobpath } (expected { expected_hash } ). "
58
- f"Please delete the cache file at { cache_path } and try again."
59
- )
55
+ if expected_hash is None or check_hash (data , expected_hash ):
60
56
return data
61
57
58
+ # the cached file does not match the hash, remove it and re-fetch
59
+ try :
60
+ os .remove (cache_path )
61
+ except OSError :
62
+ pass
63
+
62
64
contents = read_file (blobpath )
63
65
if expected_hash and not check_hash (contents , expected_hash ):
64
66
raise ValueError (
@@ -81,7 +83,10 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
81
83
82
84
83
85
def data_gym_to_mergeable_bpe_ranks (
84
- vocab_bpe_file : str , encoder_json_file : str , vocab_bpe_hash : Optional [str ]= None , encoder_json_hash : Optional [str ]= None
86
+ vocab_bpe_file : str ,
87
+ encoder_json_file : str ,
88
+ vocab_bpe_hash : Optional [str ] = None ,
89
+ encoder_json_hash : Optional [str ] = None ,
85
90
) -> dict [bytes , int ]:
86
91
# NB: do not add caching to this function
87
92
rank_to_intbyte = [b for b in range (2 ** 8 ) if chr (b ).isprintable () and chr (b ) != " " ]
@@ -135,7 +140,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
135
140
f .write (base64 .b64encode (token ) + b" " + str (rank ).encode () + b"\n " )
136
141
137
142
138
- def load_tiktoken_bpe (tiktoken_bpe_file : str , expected_hash : Optional [str ]= None ) -> dict [bytes , int ]:
143
+ def load_tiktoken_bpe (
144
+ tiktoken_bpe_file : str , expected_hash : Optional [str ] = None
145
+ ) -> dict [bytes , int ]:
139
146
# NB: do not add caching to this function
140
147
contents = read_file_cached (tiktoken_bpe_file , expected_hash )
141
148
return {
0 commit comments