Skip to content

Commit bb2e228

Browse files
committed
named_buffers
1 parent 3d56e94 commit bb2e228

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,13 +362,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
362362

363363
if is_accelerate_available():
364364
param_device = torch.device(device) if device else torch.device("cpu")
365+
named_buffers = model.named_buffers()
365366
unexpected_keys = load_model_dict_into_meta(
366367
model,
367368
diffusers_format_checkpoint,
368369
dtype=torch_dtype,
369370
device=param_device,
370371
hf_quantizer=hf_quantizer,
371372
keep_in_fp32_modules=keep_in_fp32_modules,
373+
named_buffers=named_buffers,
372374
)
373375

374376
else:

src/diffusers/models/model_loading_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from array import array
2121
from collections import OrderedDict
2222
from pathlib import Path
23-
from typing import List, Optional, Union
23+
from typing import Iterator, List, Optional, Tuple, Union
2424

2525
import safetensors
2626
import torch
@@ -185,6 +185,7 @@ def load_model_dict_into_meta(
185185
model_name_or_path: Optional[str] = None,
186186
hf_quantizer=None,
187187
keep_in_fp32_modules=None,
188+
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
188189
) -> List[str]:
189190
if device is not None and not isinstance(device, (str, torch.device)):
190191
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -246,7 +247,10 @@ def load_model_dict_into_meta(
246247
else:
247248
set_module_tensor_to_device(model, param_name, device, value=param)
248249

249-
for param_name, param in model.named_buffers():
250+
if named_buffers is None:
251+
return unexpected_keys
252+
253+
for param_name, param in named_buffers:
250254
if is_quantized and (
251255
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
252256
):

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
902902
" those weights or else make sure your checkpoint file is correct."
903903
)
904904

905+
named_buffers = model.named_buffers()
906+
905907
unexpected_keys = load_model_dict_into_meta(
906908
model,
907909
state_dict,
@@ -910,6 +912,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
910912
model_name_or_path=pretrained_model_name_or_path,
911913
hf_quantizer=hf_quantizer,
912914
keep_in_fp32_modules=keep_in_fp32_modules,
915+
named_buffers=named_buffers,
913916
)
914917

915918
if cls._keys_to_ignore_on_load_unexpected is not None:

0 commit comments

Comments
 (0)