
Commit a6ffed6

FlashPack
1 parent 01a5692 commit a6ffed6

7 files changed, +100 -0 lines changed

setup.py

Lines changed: 1 addition & 0 deletions

@@ -248,6 +248,7 @@ def run(self):
 extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
 extras["torchao"] = deps_list("torchao", "accelerate")
 extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
+extras["flashpack"] = deps_list("flashpack")
 
 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows

src/diffusers/models/model_loading_utils.py

Lines changed: 31 additions & 0 deletions

@@ -33,6 +33,7 @@
 from ..quantizers import DiffusersQuantizer
 from ..utils import (
     DEFAULT_HF_PARALLEL_LOADING_WORKERS,
+    FLASHPACK_FILE_EXTENSION,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -42,6 +43,7 @@
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_flashpack_available,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -177,6 +179,8 @@ def load_state_dict(
         return safetensors.torch.load_file(checkpoint_file, device=map_location)
     elif file_extension == GGUF_FILE_EXTENSION:
         return load_gguf_checkpoint(checkpoint_file)
+    elif file_extension == FLASHPACK_FILE_EXTENSION:
+        return load_flashpack_checkpoint(checkpoint_file)
     else:
         extra_args = {}
         weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
@@ -682,6 +686,33 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     return parsed_parameters
 
 
+def load_flashpack_checkpoint(flashpack_checkpoint_path: str):
+    """
+    Load a FlashPack file and return a dictionary of parsed parameters containing tensors.
+
+    Args:
+        flashpack_checkpoint_path (`str`):
+            The path to the FlashPack file to load.
+    """
+
+    if is_flashpack_available() and is_torch_available():
+        import flashpack
+    else:
+        logger.error(
+            "Loading a FlashPack checkpoint in PyTorch requires both PyTorch and flashpack to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
+        )
+        raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
+
+    flash_tensor, meta = flashpack.deserialization.read_flashpack_file(
+        path=flashpack_checkpoint_path,
+    )
+    state_dict = {}
+    for name, view in flashpack.deserialization.iterate_from_flash_tensor(flash_tensor, meta):
+        state_dict[name] = view
+    return state_dict
+
+
 def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
     mismatched_keys = []
     if not ignore_mismatched_sizes:
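For illustration, a minimal sketch of exercising the new loader directly; the checkpoint path is hypothetical, and both torch and flashpack must be installed:

# Hedged sketch: read a FlashPack checkpoint into a plain {name: tensor}
# state dict via the helper added above. The path is hypothetical.
from diffusers.models.model_loading_utils import load_flashpack_checkpoint

state_dict = load_flashpack_checkpoint("unet/model.flashpack")
for name, view in state_dict.items():
    # Values are tensor views into the packed buffer, per the loader above.
    print(name, tuple(view.shape), view.dtype)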

src/diffusers/models/modeling_utils.py

Lines changed: 55 additions & 0 deletions

@@ -42,6 +42,7 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     HF_ENABLE_PARALLEL_LOADING,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -55,6 +56,7 @@
     is_accelerate_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
+    is_flashpack_available,
     is_peft_available,
     is_torch_version,
     logging,
@@ -913,6 +915,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is loaded from `flashpack` weights.
 
         > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
         > with `hf auth login`. You can also activate the special
@@ -957,6 +961,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
         parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
 
         is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
         if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1185,6 +1190,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 subfolder=subfolder or "",
                 dduf_entries=dduf_entries,
             )
+        elif use_flashpack:
+            try:
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=FLASHPACK_WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+
+            except IOError as e:
+                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                if not allow_pickle:
+                    raise
+                logger.warning(
+                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                )
         elif use_safetensors:
             try:
                 resolved_model_file = _get_model_file(
@@ -1248,6 +1277,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         with ContextManagers(init_contexts):
             model = cls.from_config(config, **unused_kwargs)
 
+        if use_flashpack:
+            if is_flashpack_available():
+                import flashpack
+
+                flashpack.mixin.assign_from_file(
+                    model=model,
+                    path=resolved_model_file[0],
+                    device=None if device_map is None else device_map[""],
+                    # silent=silent,
+                    # strict=strict,
+                    # strict_params=strict_params,
+                    # strict_buffers=strict_buffers,
+                    # keep_flash_ref_on_model=keep_flash_ref_on_model,
+                    # num_streams=num_streams,
+                    # chunk_bytes=chunk_bytes,
+                    # ignore_names=ignore_names or cls.flashpack_ignore_names,
+                    # ignore_prefixes=ignore_prefixes or cls.flashpack_ignore_prefixes,
+                    # ignore_suffixes=ignore_suffixes or cls.flashpack_ignore_suffixes,
+                    # use_distributed_loading=use_distributed_loading,
+                    # rank=rank,
+                    # local_rank=local_rank,
+                    # world_size=world_size,
+                    # coerce_dtype=coerce_dtype or cls.flashpack_coerce_dtype,
+                )
+            return model
+
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
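End to end, the new flag is exercised through `from_pretrained`. A hedged sketch, assuming a hypothetical repository that ships a `model.flashpack` file (`FLASHPACK_WEIGHTS_NAME`):

# Hedged sketch: load a model from FlashPack weights. The repo id is
# hypothetical; requires the flashpack package to be installed.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "org/vae-flashpack",  # hypothetical repo id containing model.flashpack
    use_flashpack=True,
)

Note the early `return model`: when `use_flashpack` is set, weights are assigned by `flashpack.mixin.assign_from_file` (onto `device_map[""]` if a device map is given) and the regular state-dict loading path below is skipped entirely.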

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 0 deletions

@@ -413,6 +413,9 @@ def get_class_obj_and_candidates(
     """Simple helper method to retrieve class object of module as well as potential parent class objects"""
     component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
 
+    if class_name.startswith("FlashPack"):
+        class_name = class_name.removeprefix("FlashPack")
+
     if is_pipeline_module:
         pipeline_module = getattr(pipelines, library_name)

src/diffusers/utils/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,8 @@
     DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
+    FLASHPACK_FILE_EXTENSION,
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
     HF_ENABLE_PARALLEL_LOADING,
@@ -74,6 +76,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_flashpack_available,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,

src/diffusers/utils/constants.py

Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,8 @@
 SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
 SAFETENSORS_FILE_EXTENSION = "safetensors"
+FLASHPACK_WEIGHTS_NAME = "model.flashpack"
+FLASHPACK_FILE_EXTENSION = "flashpack"
 GGUF_FILE_EXTENSION = "gguf"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
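Together these constants route the new format: `from_pretrained` looks for a file named `model.flashpack` when `use_flashpack=True`, and `load_state_dict` dispatches any file with the `flashpack` extension to `load_flashpack_checkpoint`.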

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions

@@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
+_flashpack_available, _flashpack_version = _is_package_available("flashpack")
 
 
 def is_torch_available():
@@ -364,6 +365,10 @@ def is_gguf_available():
     return _gguf_available
 
 
+def is_flashpack_available():
+    return _flashpack_available
+
+
 def is_torchao_available():
     return _torchao_available
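The helper follows the module's existing `is_*_available` pattern: availability is probed once at import time and exposed through a cheap accessor. A small sketch of the guarded-import idiom the loader builds on:

# Guarded import, mirroring the availability check in
# load_flashpack_checkpoint above.
from diffusers.utils import is_flashpack_available

if is_flashpack_available():
    import flashpack  # found at probe time, safe to import
else:
    raise ImportError("Please install flashpack to load FlashPack checkpoints.")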
