Commit afd115b

pipeline
1 parent 580b3c4 commit afd115b

2 files changed: +11 -1 lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 5 additions & 1 deletion
@@ -762,6 +762,7 @@ def load_sub_model(
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
     quantization_config: Optional[Any] = None,
+    use_flashpack: bool = False,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
     from ..quantizers import PipelineQuantizationConfig
@@ -835,6 +836,9 @@ def load_sub_model(
     loading_kwargs["variant"] = model_variants.pop(name, None)
     loading_kwargs["use_safetensors"] = use_safetensors
 
+    if is_diffusers_model:
+        loading_kwargs["use_flashpack"] = use_flashpack
+
     if from_flax:
         loading_kwargs["from_flax"] = True
 
@@ -881,7 +885,7 @@ def load_sub_model(
     # else load from the root directory
     loaded_sub_model = load_method(cached_folder, **loading_kwargs)
 
-    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
+    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"
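
Note on the hunks above: the new keyword is added to loading_kwargs only when the component is a diffusers model, so components that are not (tokenizers, feature extractors, and the like) never receive an argument their loaders may not accept, and the accelerate hook-removal/offloading branch is skipped when flashpack loading is requested. A minimal standalone sketch of that keyword-guard pattern; build_loading_kwargs is a hypothetical helper name that does not appear in the diff:

from typing import Any, Dict

def build_loading_kwargs(is_diffusers_model: bool, use_flashpack: bool) -> Dict[str, Any]:
    # Mirrors the guard added in load_sub_model: the flag is attached
    # only when the target component is a diffusers model.
    loading_kwargs: Dict[str, Any] = {}
    if is_diffusers_model:
        loading_kwargs["use_flashpack"] = use_flashpack
    return loading_kwargs

# A non-diffusers component (e.g. a tokenizer) never sees the keyword:
assert "use_flashpack" not in build_loading_kwargs(False, True)
# A diffusers model receives it explicitly, even at the default False:
assert build_loading_kwargs(True, False) == {"use_flashpack": False}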

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 0 deletions
@@ -241,6 +241,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -338,6 +339,7 @@ def is_saveable_module(name, value):
         save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
         save_method_accept_variant = "variant" in save_method_signature.parameters
         save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
+        save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
 
         save_kwargs = {}
         if save_method_accept_safe:
@@ -347,6 +349,8 @@ def is_saveable_module(name, value):
         if save_method_accept_max_shard_size and max_shard_size is not None:
             # max_shard_size is expected to not be None in ModelMixin
             save_kwargs["max_shard_size"] = max_shard_size
+        if save_method_accept_flashpack:
+            save_kwargs["use_flashpack"] = use_flashpack
 
         save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
@@ -758,6 +762,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         quantization_config = kwargs.pop("quantization_config", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
 
         if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
@@ -1042,6 +1047,7 @@ def load_module(name, value):
                 dduf_entries=dduf_entries,
                 provider_options=provider_options,
                 quantization_config=quantization_config,
+                use_flashpack=use_flashpack,
             )
             logger.info(
                 f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
