Skip to content

Commit 7ba7f0b

Browse files
committed
flashpack_kwargs
1 parent e4d1553 commit 7ba7f0b

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
939939
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
940940
use_flashpack (`bool`, *optional*, defaults to `False`):
941941
If set to `True`, the model is loaded from `flashpack` weights.
942+
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
943+
Kwargs passed to [`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
944+
942945
943946
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
944947
with `hf > auth login`. You can also activate the special >
@@ -984,6 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
984987
disable_mmap = kwargs.pop("disable_mmap", False)
985988
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
986989
use_flashpack = kwargs.pop("use_flashpack", False)
990+
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
987991

988992
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
989993
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1326,21 +1330,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
13261330
model=model,
13271331
path=resolved_model_file[0],
13281332
device=flashpack_device,
1329-
# silent=silent,
1330-
# strict=strict,
1331-
# strict_params=strict_params,
1332-
# strict_buffers=strict_buffers,
1333-
# keep_flash_ref_on_model=keep_flash_ref_on_model,
1334-
# num_streams=num_streams,
1335-
# chunk_bytes=chunk_bytes,
1336-
# ignore_names=ignore_names or cls.flashpack_ignore_names,
1337-
# ignore_prefixes=ignore_prefixes or cls.flashpack_ignore_prefixes,
1338-
# ignore_suffixes=ignore_suffixes or cls.flashpack_ignore_suffixes,
1339-
# use_distributed_loading=use_distributed_loading,
1340-
# rank=rank,
1341-
# local_rank=local_rank,
1342-
# world_size=world_size,
1343-
# coerce_dtype=coerce_dtype or cls.flashpack_coerce_dtype,
1333+
**flashpack_kwargs,
13441334
)
13451335

13461336
if output_loading_info:

0 commit comments

Comments
 (0)