src/diffusers/loaders/single_file_model.py (2 additions, 0 deletions)

@@ -362,13 +362,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
 
         else:
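Why the extra argument is needed: `load_model_dict_into_meta` only walks the checkpoint's parameters, so buffers that are not in the state dict (non-persistent caches, for example) can be left behind on the wrong device. A minimal sketch of the failure mode, with a hypothetical `TinyModel` and `accelerate`'s `init_empty_weights(include_buffers=True)` as one path where buffers land on `meta`:

```python
import torch
from accelerate import init_empty_weights


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # A buffer, e.g. a positional cache, that is absent from the checkpoint.
        self.register_buffer("freqs", torch.arange(4, dtype=torch.float32), persistent=False)


with init_empty_weights(include_buffers=True):
    model = TinyModel()

# Both parameters and buffers are placeholders on the meta device.
print(model.linear.weight.device)  # meta
print(model.freqs.device)          # meta

# Loading a checkpoint restores the parameters, but nothing in the state dict
# names the non-persistent buffer, so without an explicit buffer pass it is
# never materialized on the target device.
```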
src/diffusers/models/model_loading_utils.py (16 additions, 1 deletion)

@@ -20,7 +20,7 @@
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 
 import safetensors
 import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
         else:
             set_module_tensor_to_device(model, param_name, device, value=param)
 
+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
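This is the core of the change: after all checkpoint parameters have been placed, the new loop gives each buffer the same treatment, either routing it through the quantizer or assigning it with `set_module_tensor_to_device`. A standalone sketch of the non-quantized branch, reusing the hypothetical `TinyModel` from above (`include_buffers=False` keeps buffers as real tensors, a minimal stand-in for the loading setup):

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer("freqs", torch.arange(4, dtype=torch.float32), persistent=False)


# Parameters go to meta; buffers stay behind as real CPU tensors.
with init_empty_weights(include_buffers=False):
    model = TinyModel()

# Mirror the new buffer pass in load_model_dict_into_meta: walk named_buffers()
# and re-assign each buffer on the target device, using the buffer itself as
# the value (the hf_quantizer branch is omitted here).
device = torch.device("cpu")  # would be the GPU in the quantized-loading case
for name, buffer in model.named_buffers():
    set_module_tensor_to_device(model, name, device, value=buffer)

print(model.freqs.device)  # cpu
```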
src/diffusers/models/modeling_utils.py (3 additions, 0 deletions)

@@ -913,6 +913,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     " those weights or else make sure your checkpoint file is correct."
                 )
 
+            named_buffers = model.named_buffers()
+
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 state_dict,
@@ -921,6 +923,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 model_name_or_path=pretrained_model_name_or_path,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
 
             if cls._keys_to_ignore_on_load_unexpected is not None:
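`from_pretrained` gets the same treatment as the single-file loader: buffers are collected before the load and handed to `load_model_dict_into_meta`. A quick end-to-end check of the behavior, using the model repo and quantization config from the test added below (assumes a CUDA GPU with `bitsandbytes` and `accelerate` installed):

```python
import torch
from diffusers import BitsAndBytesConfig, SanaTransformer2DModel

model = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# With the fix, buffers follow the quantized weights to the GPU instead of
# being left on the CPU or meta device.
for name, buffer in model.named_buffers():
    print(name, buffer.device)
```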
tests/quantization/bnb/test_mixed_int8.py (35 additions, 1 deletion)

@@ -20,7 +20,14 @@
 import pytest
 from huggingface_hub import hf_hub_download
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -302,6 +309,33 @@ def test_device_and_dtype_assignment(self):
         _ = self.model_fp16.cuda()
 
 
+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
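The new test asserts that every buffer of the 8-bit Sana transformer ends up on the accelerator, which is exactly the regression this PR fixes. A plausible way to run it in isolation (requires a CUDA machine with `bitsandbytes` installed, since the quantized model is loaded in `setUp`): `pytest tests/quantization/bnb/test_mixed_int8.py -k test_buffers_device_assignment`.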