9292    ALL_IMPORTABLE_CLASSES .update (LOADABLE_CLASSES [library ])
9393
9494
95- def  is_safetensors_compatible (filenames , passed_components = None , folder_names = None ) ->  bool :
95+ def  is_safetensors_compatible (filenames , passed_components = None , folder_names = None ,  variant = None ) ->  bool :
9696    """ 
9797    Checking for safetensors compatibility: 
9898    - The model is safetensors compatible only if there is a safetensors file for each model component present in 
@@ -103,6 +103,28 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
103103    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" 
104104      extension is replaced with ".safetensors" 
105105    """ 
106+     weight_names  =  [
107+         WEIGHTS_NAME ,
108+         SAFETENSORS_WEIGHTS_NAME ,
109+         FLAX_WEIGHTS_NAME ,
110+         ONNX_WEIGHTS_NAME ,
111+         ONNX_EXTERNAL_WEIGHTS_NAME ,
112+     ]
113+ 
114+     if  is_transformers_available ():
115+         weight_names  +=  [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
116+ 
117+     # model_pytorch, diffusion_model_pytorch, ... 
118+     weight_prefixes  =  [w .split ("." )[0 ] for  w  in  weight_names ]
119+     # .bin, .safetensors, ... 
120+     weight_suffixs  =  [w .split ("." )[- 1 ] for  w  in  weight_names ]
121+     # -00001-of-00002 
122+     transformers_index_format  =  r"\d{5}-of-\d{5}" 
123+     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` 
124+     non_variant_file_re  =  re .compile (
125+         rf"({ '|' .join (weight_prefixes )} { transformers_index_format } { '|' .join (weight_suffixs )}  
126+     )
127+ 
106128    passed_components  =  passed_components  or  []
107129    if  folder_names :
108130        filenames  =  {f  for  f  in  filenames  if  os .path .split (f )[0 ] in  folder_names }
@@ -130,6 +152,8 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
130152    for  component , component_filenames  in  components .items ():
131153        matches  =  []
132154        for  component_filename  in  component_filenames :
155+             if  variant  is  None :
156+                 component_filename  =  filter_with_regex (component_filename , non_variant_file_re )
133157            filename , extension  =  os .path .splitext (component_filename )
134158
135159            match_exists  =  extension  ==  ".safetensors" 
@@ -158,6 +182,8 @@ def filter_model_files(filenames):
158182
159183    return  [f  for  f  in  filenames  if  any (f .endswith (extension ) for  extension  in  allowed_extensions )]
160184
185+ def  filter_with_regex (filenames , pattern_re ):
186+     return  {f  for  f  in  filenames  if  pattern_re .match (f .split ("/" )[- 1 ]) is  not None }
161187
162188def  variant_compatible_siblings (filenames , variant = None , ignore_patterns = None ) ->  Union [List [os .PathLike ], str ]:
163189    weight_names  =  [
@@ -207,9 +233,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
207233        # interested in the extension name 
208234        return  {f  for  f  in  filenames  if  not  any (f .endswith (pat .lstrip ("*." )) for  pat  in  ignore_patterns )}
209235
210-     def  filter_with_regex (filenames , pattern_re ):
211-         return  {f  for  f  in  filenames  if  pattern_re .match (f .split ("/" )[- 1 ]) is  not None }
212- 
213236    # Group files by component 
214237    components  =  {}
215238    for  filename  in  filenames :
@@ -997,7 +1020,7 @@ def _get_ignore_patterns(
9971020        use_safetensors 
9981021        and  not  allow_pickle 
9991022        and  not  is_safetensors_compatible (
1000-             model_filenames , passed_components = passed_components , folder_names = model_folder_names 
1023+             model_filenames , passed_components = passed_components , folder_names = model_folder_names ,  variant = variant 
10011024        )
10021025    ):
10031026        raise  EnvironmentError (
@@ -1008,7 +1031,7 @@ def _get_ignore_patterns(
10081031        ignore_patterns  =  ["*.bin" , "*.safetensors" , "*.onnx" , "*.pb" ]
10091032
10101033    elif  use_safetensors  and  is_safetensors_compatible (
1011-         model_filenames , passed_components = passed_components , folder_names = model_folder_names 
1034+         model_filenames , passed_components = passed_components , folder_names = model_folder_names ,  variant = variant 
10121035    ):
10131036        ignore_patterns  =  ["*.bin" , "*.msgpack" ]
10141037
0 commit comments