Merged · Changes from 5 commits
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -62,7 +62,7 @@
 if is_accelerate_available():
     from accelerate import dispatch_model, init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 
 SINGLE_FILE_LOADABLE_CLASSES = {
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_utils.py
@@ -55,7 +55,7 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

3 changes: 2 additions & 1 deletion src/diffusers/loaders/transformer_sd3.py
@@ -16,7 +16,8 @@

 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache

3 changes: 2 additions & 1 deletion src/diffusers/loaders/unet.py
@@ -30,7 +30,8 @@
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
125 changes: 125 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -20,6 +20,7 @@
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -310,6 +311,130 @@ def load_model_dict_into_meta(
    return offload_index, state_dict_index


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
Member Author
Moved it here from modeling_utils.py.

    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
    parameters.

    """
    if model_to_load.device.type == "meta":
        return False

    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = next(iter(model_to_load.state_dict().keys()))
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    return False

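For orientation only (this example is not part of the diff): a toy run of the dtype check above. The stand-in module defines a `device` property because the helper reads `model_to_load.device`, which diffusers' `ModelMixin` provides; assignment is allowed when the checkpoint tensors match the materialized model's dtype and refused otherwise.

import torch
from torch import nn


class TinyModel(nn.Module):
    # Minimal stand-in: diffusers' ModelMixin exposes a `device` property, which the check relies on.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self):
        return next(self.parameters()).device


model = TinyModel()  # materialized on CPU in float32
same_dtype_sd = {k: v.clone() for k, v in model.state_dict().items()}
half_dtype_sd = {k: v.half() for k, v in model.state_dict().items()}

assert check_support_param_buffer_assignment(model, same_dtype_sd)      # dtypes match -> True
assert not check_support_param_buffer_assignment(model, half_dtype_sd)  # fp16 vs fp32 -> False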

def load_shard_file(args):
    (
        model,
        model_state_dict,
        shard_file,
        device_map,
        dtype,
        hf_quantizer,
        keep_in_fp32_modules,
        dduf_entries,
        loaded_keys,
        unexpected_keys,
        offload_index,
        offload_folder,
        state_dict_index,
        state_dict_folder,
        ignore_mismatched_sizes,
        low_cpu_mem_usage,
    ) = args
    assign_to_params_buffers = None
    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
    mismatched_keys = _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )
    error_msgs = []
    if low_cpu_mem_usage:
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            state_dict,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
    else:
        if assign_to_params_buffers is None:
            assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
    return offload_index, state_dict_index, mismatched_keys, error_msgs


def load_shard_files_with_threadpool(args_list):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
Collaborator
Would add HF_PARALLEL_LOADING_WORKERS as a constant at the top of the file for consistency.


    # Do not spawn any more workers than you need
    num_workers = min(len(args_list), num_workers)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                result = future.result()
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs

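On the suggestion above about `HF_PARALLEL_LOADING_WORKERS`: a minimal sketch, assuming the constant simply mirrors the environment variable it replaces (the name and placement are illustrative, not the merged code):

import os

# Assumed module-level constant; keeps the same default of 8 workers as the inline lookup above.
HF_PARALLEL_LOADING_WORKERS = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))


def load_shard_files_with_threadpool(args_list):
    # Do not spawn more workers than there are shards to load.
    num_workers = min(len(args_list), HF_PARALLEL_LOADING_WORKERS)
    ...

One design consideration: hoisting the read to import time freezes the value for the whole process, while the inline lookup above picks up an environment change made before each load.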

def _find_mismatched_keys(
Member Author
Same. Moved it out of modeling_utils.py.

    state_dict,
    model_state_dict,
    loaded_keys,
    ignore_mismatched_sizes,
):
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key
            # If the checkpoint is sharded, we may not have the key here.
            if checkpoint_key not in state_dict:
                continue

            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
                mismatched_keys.append(
                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                )
                del state_dict[checkpoint_key]
    return mismatched_keys

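Again purely illustrative (not in the diff): with `ignore_mismatched_sizes=True`, a checkpoint tensor whose shape disagrees with the model's is recorded and deleted from the incoming state dict so it is never loaded.

import torch

model_sd = {"proj.weight": torch.zeros(8, 4)}
ckpt_sd = {"proj.weight": torch.zeros(16, 4)}  # deliberately the wrong shape

mismatched = _find_mismatched_keys(ckpt_sd, model_sd, ["proj.weight"], True)
assert mismatched == [("proj.weight", torch.Size([16, 4]), torch.Size([8, 4]))]
assert "proj.weight" not in ckpt_sd  # removed, so it will not be loaded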

def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
109 changes: 49 additions & 60 deletions src/diffusers/models/modeling_utils.py
@@ -45,6 +45,7 @@
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
+    ENV_VARS_TRUE_VALUES,
     WEIGHTS_NAME,
     _add_variant,
     _get_checkpoint_shard_files,
@@ -69,9 +70,8 @@
    _expand_device_map,
    _fetch_index_file,
    _fetch_index_file_legacy,
    _find_mismatched_keys,
    _load_state_dict_into_model,
    load_model_dict_into_meta,
    load_shard_file,
    load_shard_files_with_threadpool,
    load_state_dict,
)

@@ -208,34 +208,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return last_tuple[1].dtype


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
    parameters.

    """
    if model_to_load.device.type == "meta":
        return False

    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = next(iter(model_to_load.state_dict().keys()))
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    return False


@contextmanager
def no_init_weights():
"""
@@ -988,6 +960,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)

is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES

if is_parallel_loading_enabled and not low_cpu_mem_usage:
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")

if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
@@ -1323,6 +1300,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1518,6 +1496,7 @@ def _load_pretrained_model(
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
is_parallel_loading_enabled: Optional[bool] = False,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@@ -1531,6 +1510,9 @@
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

mismatched_keys = []
error_msgs = []

# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1565,38 +1547,45 @@
# load_state_dict will manage the case where we pass a dict instead of a file
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
resolved_model_file = [state_dict]
is_file = resolved_model_file and state_dict is None

# prepare the arguments.
args_list = [
(
model,
model_state_dict,
shard_file,
device_map,
dtype,
hf_quantizer,
keep_in_fp32_modules,
dduf_entries,
loaded_keys,
unexpected_keys,
offload_index,
offload_folder,
state_dict_index,
state_dict_folder,
ignore_mismatched_sizes,
low_cpu_mem_usage,
)
for shard_file in resolved_model_file
]
Collaborator @DN6, Aug 12, 2025
Since the same arguments are used across the two loading functions, it's a good candidate for functools.partial:

        load_fn = partial(
            load_shard_files_with_threadpool if is_parallel_loading_enabled else load_shard_file,
            model=model,
            model_state_dict=model_state_dict,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            dduf_entries=dduf_entries,
            loaded_keys=loaded_keys,
            unexpected_keys=unexpected_keys,
            offload_index=offload_index,
            offload_folder=offload_folder,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

        if is_parallel_loading_enabled:
            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(
                resolved_model_file,
            )
            error_msgs += _error_msgs
            mismatched_keys += _mismatched_keys
        else:
            shard_files = resolved_model_file
            if len(resolved_model_file) > 1:
                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

            for shard_file in resolved_model_file:
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys


if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []

for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
mismatched_keys += _find_mismatched_keys(
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
Member Author (on lines -1569 to -1579)
This has been moved to load_shard_file().

if is_parallel_loading_enabled and is_file:
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
args_list
)
error_msgs += _error_msgs
mismatched_keys += _mismatched_keys
else:
if len(args_list) > 1:
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

if low_cpu_mem_usage:
offload_index, state_dict_index = load_model_dict_into_meta(
model,
state_dict,
device_map=device_map,
dtype=dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_index=state_dict_index,
state_dict_folder=state_dict_folder,
)
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
for args in args_list:
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
error_msgs += _error_msgs
mismatched_keys += _mismatched_keys

empty_device_cache()

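To tie the pieces together, a minimal sketch of how the opt-in path added here would be exercised from user code. It assumes only what this diff introduces: `HF_ENABLE_PARALLEL_LOADING` is checked against `ENV_VARS_TRUE_VALUES`, `HF_PARALLEL_LOADING_WORKERS` defaults to 8, and the `low_cpu_mem_usage` path must stay enabled or `from_pretrained` raises `NotImplementedError`. The checkpoint id is only an example of a sharded model.

import os

# Both variables are read while from_pretrained runs, so set them beforehand.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"  # any value in ENV_VARS_TRUE_VALUES
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"   # optional; defaults to 8 workers

from diffusers import FluxTransformer2DModel

# Shards are dispatched concurrently through load_shard_files_with_threadpool.
# low_cpu_mem_usage defaults to True when accelerate is installed and must not be disabled here.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example id; the benefit shows up on multi-shard checkpoints
    subfolder="transformer",
    low_cpu_mem_usage=True,
)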