4848if is_transformers_available ():
4949 import transformers
5050 from transformers import PreTrainedModel , PreTrainedTokenizerBase
51- from transformers .utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
5251 from transformers .utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
5352 from transformers .utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
5453
54+ if is_transformers_version ("<=" , "4.56.2" ):
55+ from transformers .utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
56+
5557if is_accelerate_available ():
5658 import accelerate
5759 from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
112114 ]
113115
114116 if is_transformers_available ():
115- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
117+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
118+ if is_transformers_version ("<=" , "4.56.2" ):
119+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
116120
117121 # model_pytorch, diffusion_model_pytorch, ...
118122 weight_prefixes = [w .split ("." )[0 ] for w in weight_names ]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
191195 ]
192196
193197 if is_transformers_available ():
194- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
198+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
199+ if is_transformers_version ("<=" , "4.56.2" ):
200+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
195201
196202 allowed_extensions = [wn .split ("." )[- 1 ] for wn in weight_names ]
197203
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
212218 ]
213219
214220 if is_transformers_available ():
215- weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
221+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME ]
222+ if is_transformers_version ("<=" , "4.56.2" ):
223+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME ]
216224
217225 # model_pytorch, diffusion_model_pytorch, ...
218226 weight_prefixes = [w .split ("." )[0 ] for w in weight_names ]
0 commit comments