Commit 73e5897

add flashpack support with safetensors fallback
1 parent 544ba67 commit 73e5897

3 files changed: +83 −0 lines changed
src/diffusers/models/modeling_utils.py

Lines changed: 72 additions & 0 deletions
@@ -916,6 +916,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically
+                falls back to the standard loading path (for example, `safetensors`). Requires the `flashpack`
+                library: `pip install flashpack`.
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -959,6 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1177,6 +1183,72 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
+            if use_flashpack:
+                try:
+                    from flashpack import assign_from_file
+                except ImportError:
+                    pass
+                else:
+                    flashpack_weights_name = _add_variant("model.flashpack", variant)
+
+                    try:
+                        flashpack_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=flashpack_weights_name,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                        )
+                    except EnvironmentError:
+                        pass
+                    else:
+                        dtype_orig = None
+                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                            if not isinstance(torch_dtype, torch.dtype):
+                                raise ValueError(
+                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                                )
+                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+                        with no_init_weights():
+                            model = cls.from_config(config, **unused_kwargs)
+
+                        if dtype_orig is not None:
+                            torch.set_default_dtype(dtype_orig)
+
+                        # flashpack requires a single dtype across all parameters
+                        param_dtypes = {p.dtype for p in model.parameters()}
+                        if len(param_dtypes) > 1:
+                            pass
+                        else:
+                            try:
+                                assign_from_file(model, flashpack_file)
+                                model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+                                if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                                    model = model.to(torch_dtype)
+
+                                model.eval()
+
+                                if output_loading_info:
+                                    loading_info = {
+                                        "missing_keys": [],
+                                        "unexpected_keys": [],
+                                        "mismatched_keys": [],
+                                        "error_msgs": [],
+                                    }
+                                    return model, loading_info
+
+                                return model
+
+                            except Exception:
+                                pass
             # in the case it is sharded, we have already the index
             if is_sharded:
                 resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
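The new branch is opt-in and degrades silently: each bare `except ... pass` above simply falls through to the existing safetensors/sharded loading code below it. A minimal usage sketch of the model-level flag follows; the repository id is a placeholder, and a real checkpoint would need a `model.flashpack` file alongside its config for the fast path to apply.

import torch
from diffusers import UNet2DConditionModel

# Hypothetical repository id, used only for illustration. If the `flashpack` library
# is missing or no `.flashpack` file is found, loading falls back to safetensors.
unet = UNet2DConditionModel.from_pretrained(
    "your-org/your-model",      # assumption: placeholder repo that ships model.flashpack
    subfolder="unet",
    torch_dtype=torch.float16,  # flashpack assignment expects a single dtype across parameters
    use_flashpack=True,         # flag introduced by this commit
)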

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 4 additions & 0 deletions
@@ -756,6 +756,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    use_flashpack: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
     quantization_config: Optional[Any] = None,
@@ -832,6 +833,9 @@ def load_sub_model(
         loading_kwargs["variant"] = model_variants.pop(name, None)
         loading_kwargs["use_safetensors"] = use_safetensors
 
+        if is_diffusers_model:
+            loading_kwargs["use_flashpack"] = use_flashpack
+
         if from_flax:
             loading_kwargs["from_flax"] = True
 
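The `is_diffusers_model` guard keeps the new kwarg scoped to diffusers models: components loaded through other libraries (tokenizers, schedulers, transformers text encoders) never receive `use_flashpack`, so their loading signatures stay untouched. A hedged sketch of that intent, using a hypothetical helper that is not part of this diff:

from diffusers import ModelMixin

def build_loading_kwargs(component_cls, use_safetensors, use_flashpack):
    # Sketch only: mirror the guard's intent. Diffusers ModelMixin subclasses
    # understand `use_flashpack`; every other component keeps its existing kwargs.
    kwargs = {"use_safetensors": use_safetensors}
    if issubclass(component_cls, ModelMixin):
        kwargs["use_flashpack"] = use_flashpack
    return kwargs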

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
@@ -693,6 +693,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                 weights. If set to `False`, safetensors weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically
+                falls back to the standard loading path (for example, `safetensors`). Requires the `flashpack`
+                library: `pip install flashpack`.
             use_onnx (`bool`, *optional*, defaults to `None`):
                 If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                 will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -755,6 +760,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         dduf_file = kwargs.pop("dduf_file", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         quantization_config = kwargs.pop("quantization_config", None)
@@ -1039,6 +1045,7 @@ def load_module(name, value):
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 cached_folder=cached_folder,
                 use_safetensors=use_safetensors,
+                use_flashpack=use_flashpack,
                 dduf_entries=dduf_entries,
                 provider_options=provider_options,
                 quantization_config=quantization_config,
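At the pipeline level the flag travels from `DiffusionPipeline.from_pretrained` through `load_sub_model` into each diffusers sub-model. A minimal sketch; the repository id is a placeholder, and any sub-model folder without a `.flashpack` file simply loads via safetensors as before:

import torch
from diffusers import DiffusionPipeline

# Hypothetical repository id, used only for illustration.
pipe = DiffusionPipeline.from_pretrained(
    "your-org/your-pipeline",   # assumption: placeholder repo
    torch_dtype=torch.float16,
    use_flashpack=True,         # forwarded to every diffusers sub-model via load_sub_model
)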
