@@ -916,6 +916,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from `flashpack` weights when a compatible `.flashpack`
+                file is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading
+                automatically falls back to the standard path (for example, `safetensors`). Requires the `flashpack`
+                library: `pip install flashpack`.
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -959,6 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1177,6 +1183,76 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
+            if use_flashpack:
+                try:
+                    from flashpack import assign_from_file
+                except ImportError:
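+                    # flashpack is not installed; fall back to the standard loading path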
+                    pass
+                else:
+                    flashpack_weights_name = _add_variant("model.flashpack", variant)
+
+                    try:
+                        flashpack_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=flashpack_weights_name,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                        )
+                    except EnvironmentError:
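+                        # no `.flashpack` file was found for this checkpoint; fall back to the standard loading path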
+                        pass
+                    else:
+                        dtype_orig = None
+                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                            if not isinstance(torch_dtype, torch.dtype):
+                                raise ValueError(
+                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                                )
+                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+                        with no_init_weights():
+                            model = cls.from_config(config, **unused_kwargs)
+
+                        if dtype_orig is not None:
+                            torch.set_default_dtype(dtype_orig)
+
+                        # flashpack requires a single dtype across all parameters
+                        param_dtypes = {p.dtype for p in model.parameters()}
+                        if len(param_dtypes) > 1:
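+                            # mixed-dtype models cannot be assigned from a single flashpack file; fall back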
+                            pass
+                        else:
+                            try:
+                                assign_from_file(model, flashpack_file)
+                                model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+                                if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                                    model = model.to(torch_dtype)
+
+                                model.eval()
+
+                                if output_loading_info:
+                                    loading_info = {
+                                        "missing_keys": [],
+                                        "unexpected_keys": [],
+                                        "mismatched_keys": [],
+                                        "error_msgs": [],
+                                    }
+                                    return model, loading_info
+
+                                return model
+
+                            except Exception:
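+                                # loading from the `.flashpack` file failed; fall back to the standard loading path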
+                                pass
             # in the case it is sharded, we have already the index
             if is_sharded:
                 resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(