Skip to content

Commit ee5cbdd

Browse files
authored
Merge branch 'main' into style-bot
2 parents f778226 + f5929e0 commit ee5cbdd

20 files changed

+933
-525
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,20 @@ def set_adapters(
661661
adapter_names: Union[List[str], str],
662662
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
663663
):
664-
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
664+
if isinstance(adapter_weights, dict):
665+
components_passed = set(adapter_weights.keys())
666+
lora_components = set(self._lora_loadable_modules)
667+
668+
invalid_components = sorted(components_passed - lora_components)
669+
if invalid_components:
670+
logger.warning(
671+
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
672+
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
673+
"to the invalid components will be removed and ignored."
674+
)
675+
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
665676

677+
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
666678
adapter_weights = copy.deepcopy(adapter_weights)
667679

668680
# Expand weights into a list, one entry per adapter
@@ -697,12 +709,6 @@ def set_adapters(
697709
for adapter_name, weights in zip(adapter_names, adapter_weights):
698710
if isinstance(weights, dict):
699711
component_adapter_weights = weights.pop(component, None)
700-
701-
if component_adapter_weights is not None and not hasattr(self, component):
702-
logger.warning(
703-
f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
704-
)
705-
706712
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
707713
logger.warning(
708714
(

src/diffusers/loaders/single_file.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from huggingface_hub import snapshot_download
2020
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
2121
from packaging import version
22+
from typing_extensions import Self
2223

2324
from ..utils import deprecate, is_transformers_available, logging
2425
from .single_file_utils import (
@@ -269,7 +270,7 @@ class FromSingleFileMixin:
269270

270271
@classmethod
271272
@validate_hf_hub_args
272-
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
273+
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
273274
r"""
274275
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
275276
format. The pipeline is set in evaluation mode (`model.eval()`) by default.

src/diffusers/loaders/single_file_model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from huggingface_hub.utils import validate_hf_hub_args
22+
from typing_extensions import Self
2223

2324
from ..quantizers import DiffusersAutoQuantizer
2425
from ..utils import deprecate, is_accelerate_available, logging
@@ -51,7 +52,7 @@
5152

5253

5354
if is_accelerate_available():
54-
from accelerate import init_empty_weights
55+
from accelerate import dispatch_model, init_empty_weights
5556

5657
from ..models.modeling_utils import load_model_dict_into_meta
5758

@@ -148,7 +149,7 @@ class FromOriginalModelMixin:
148149

149150
@classmethod
150151
@validate_hf_hub_args
151-
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
152+
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
152153
r"""
153154
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
154155
is set in evaluation mode (`model.eval()`) by default.
@@ -365,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
365366
keep_in_fp32_modules=keep_in_fp32_modules,
366367
)
367368

369+
device_map = None
368370
if is_accelerate_available():
369371
param_device = torch.device(device) if device else torch.device("cpu")
370-
named_buffers = model.named_buffers()
371-
unexpected_keys = load_model_dict_into_meta(
372+
empty_state_dict = model.state_dict()
373+
unexpected_keys = [
374+
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
375+
]
376+
device_map = {"": param_device}
377+
load_model_dict_into_meta(
372378
model,
373379
diffusers_format_checkpoint,
374380
dtype=torch_dtype,
375-
device=param_device,
381+
device_map=device_map,
376382
hf_quantizer=hf_quantizer,
377383
keep_in_fp32_modules=keep_in_fp32_modules,
378-
named_buffers=named_buffers,
384+
unexpected_keys=unexpected_keys,
379385
)
380-
381386
else:
382387
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
383388

@@ -399,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
399404

400405
model.eval()
401406

407+
if device_map is not None:
408+
device_map_kwargs = {"device_map": device_map}
409+
dispatch_model(model, **device_map_kwargs)
410+
402411
return model

src/diffusers/loaders/single_file_utils.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
15931593
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
15941594

15951595
if is_accelerate_available():
1596-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1596+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
15971597
else:
1598-
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1599-
1600-
if model._keys_to_ignore_on_load_unexpected is not None:
1601-
for pat in model._keys_to_ignore_on_load_unexpected:
1602-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1603-
1604-
if len(unexpected_keys) > 0:
1605-
logger.warning(
1606-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1607-
)
1598+
model.load_state_dict(diffusers_format_checkpoint, strict=False)
16081599

16091600
if torch_dtype is not None:
16101601
model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
20612052
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
20622053

20632054
if is_accelerate_available():
2064-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2065-
if model._keys_to_ignore_on_load_unexpected is not None:
2066-
for pat in model._keys_to_ignore_on_load_unexpected:
2067-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
2068-
2069-
if len(unexpected_keys) > 0:
2070-
logger.warning(
2071-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
2072-
)
2073-
2055+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
20742056
else:
20752057
model.load_state_dict(diffusers_format_checkpoint)
20762058

0 commit comments

Comments
 (0)