99from argparse import ArgumentParser
1010from pathlib import Path
1111import torch
12+ import re
1213
1314import deepspeed
1415from deepspeed .accelerator import get_accelerator
@@ -156,7 +157,7 @@ def get_repo_root(model_name_or_path):
156157 model_name_or_path ,
157158 local_files_only = is_offline_mode (),
158159 cache_dir = os .getenv ("TRANSFORMERS_CACHE" , None ),
159- ignore_patterns = ["*.safetensors" , "*. msgpack" , "*.h5" ],
160+ ignore_patterns = ["*.msgpack" , "*.h5" ],
160161 resume_download = True ,
161162 )
162163
@@ -166,17 +167,23 @@ def get_repo_root(model_name_or_path):
166167 model_name_or_path ,
167168 local_files_only = is_offline_mode (),
168169 cache_dir = os .getenv ("TRANSFORMERS_CACHE" , None ),
169- ignore_patterns = ["*.safetensors" , "*. msgpack" , "*.h5" ],
170+ ignore_patterns = ["*.msgpack" , "*.h5" ],
170171 resume_download = True ,
171172 )
172173
173174
def get_checkpoint_files(model_name_or_path):
    """Return the cached checkpoint weight files for a model.

    Args:
        model_name_or_path: Hugging Face model id or local path; forwarded
            to ``get_repo_root`` to resolve the snapshot cache directory.

    Returns:
        list[str]: paths (as strings) of every regular file under the cached
        repo directory whose name ends in ``.bin``, ``.pt`` or ``.safetensors``.
    """
    cached_repo_dir = get_repo_root(model_name_or_path)

    # extensions: .bin | .pt | .safetensors
    # The dot must be escaped: an unescaped '.' matches any character, so
    # e.g. a file named "cabin" would wrongly match the "bin" extension.
    checkpoint_pattern = re.compile(r".*\.(safetensors|bin|pt)$")

    # Walk everything downloaded in the cache dir; keep the is_file() guard
    # so a directory whose name ends in a matching extension is not returned.
    file_list = [
        str(entry)
        for entry in Path(cached_repo_dir).rglob("*")
        if entry.is_file() and checkpoint_pattern.match(str(entry))
    ]
    return file_list
181188
182189
0 commit comments