From 570d811e793b908edfebed227490e95e329cecb2 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 10 Jan 2025 10:26:23 +0000 Subject: [PATCH 1/3] Move buffers to device --- src/diffusers/models/model_loading_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a3d006f18994..c7ce9d0c0ffc 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -246,6 +246,17 @@ def load_model_dict_into_meta( else: set_module_tensor_to_device(model, param_name, device, value=param) + for param_name, param in model.named_buffers(): + if is_quantized and ( + hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + ): + hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) + else: + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) + else: + set_module_tensor_to_device(model, param_name, device, value=param) + return unexpected_keys From 3d56e94b7248a32998c1fe347584d14f04518507 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 10 Jan 2025 11:55:25 +0000 Subject: [PATCH 2/3] add test --- tests/quantization/bnb/test_mixed_int8.py | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index f474a1d4f4d0..9c727d0fcde4 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -19,7 +19,14 @@ import numpy as np import pytest -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging +from diffusers import ( + BitsAndBytesConfig, + DiffusionPipeline, + FluxTransformer2DModel, + SanaTransformer2DModel, + SD3Transformer2DModel, + logging, +) from diffusers.utils import is_accelerate_version from diffusers.utils.testing_utils import ( CaptureLogger, @@ -300,6 +307,33 @@ def test_device_and_dtype_assignment(self): _ = self.model_fp16.cuda() +class Bnb8bitDeviceTests(Base8bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SanaTransformer2DModel.from_pretrained( + "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers", + subfolder="transformer", + quantization_config=mixed_int8_config, + ) + + def tearDown(self): + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_buffers_device_assignment(self): + for buffer_name, buffer in self.model_8bit.named_buffers(): + self.assertEqual( + buffer.device.type, + torch.device(torch_device).type, + f"Expected device {torch_device} for {buffer_name} got {buffer.device}.", + ) + + class BnB8bitTrainingTests(Base8bitTests): def setUp(self): gc.collect() From bb2e2281d10105c5aa0caa5374d4a056494190fd Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 14 Jan 2025 07:19:04 +0000 Subject: [PATCH 3/3] named_buffers --- src/diffusers/loaders/single_file_model.py | 2 ++ src/diffusers/models/model_loading_utils.py | 8 ++++++-- src/diffusers/models/modeling_utils.py | 3 +++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 69ab8b6bad20..c7d0fcb3046e 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -362,6 +362,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if is_accelerate_available(): param_device = torch.device(device) if device else torch.device("cpu") + named_buffers = model.named_buffers() unexpected_keys = load_model_dict_into_meta( model, diffusers_format_checkpoint, @@ -369,6 +370,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = device=param_device, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, + named_buffers=named_buffers, ) else: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index c7ce9d0c0ffc..67af76753bef 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -20,7 +20,7 @@ from array import array from collections import OrderedDict from pathlib import Path -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import safetensors import torch @@ -185,6 +185,7 @@ def load_model_dict_into_meta( model_name_or_path: Optional[str] = None, hf_quantizer=None, keep_in_fp32_modules=None, + named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None, ) -> List[str]: if device is not None and not isinstance(device, (str, torch.device)): raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") @@ -246,7 +247,10 @@ def load_model_dict_into_meta( else: set_module_tensor_to_device(model, param_name, device, value=param) - for param_name, param in model.named_buffers(): + if named_buffers is None: + return unexpected_keys + + for param_name, param in named_buffers: if is_quantized and ( hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) ): diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 17e9d2043150..9b3553e8a67d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -902,6 +902,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " those weights or else make sure your checkpoint file is correct." ) + named_buffers = model.named_buffers() + unexpected_keys = load_model_dict_into_meta( model, state_dict, @@ -910,6 +912,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model_name_or_path=pretrained_model_name_or_path, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, + named_buffers=named_buffers, ) if cls._keys_to_ignore_on_load_unexpected is not None: