Skip to content

Commit 0b065c0

Browse files
authored
Move buffers to device (#10523)
* Move buffers to device
* add test
* named_buffers
1 parent b785ddb commit 0b065c0

File tree

4 files changed: 56 additions and 2 deletions.

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,13 +362,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
362362

363363
if is_accelerate_available():
364364
param_device = torch.device(device) if device else torch.device("cpu")
365+
named_buffers = model.named_buffers()
365366
unexpected_keys = load_model_dict_into_meta(
366367
model,
367368
diffusers_format_checkpoint,
368369
dtype=torch_dtype,
369370
device=param_device,
370371
hf_quantizer=hf_quantizer,
371372
keep_in_fp32_modules=keep_in_fp32_modules,
373+
named_buffers=named_buffers,
372374
)
373375

374376
else:

src/diffusers/models/model_loading_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from array import array
2121
from collections import OrderedDict
2222
from pathlib import Path
23-
from typing import Dict, List, Optional, Union
23+
from typing import Dict, Iterator, List, Optional, Tuple, Union
2424

2525
import safetensors
2626
import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
193193
model_name_or_path: Optional[str] = None,
194194
hf_quantizer=None,
195195
keep_in_fp32_modules=None,
196+
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
196197
) -> List[str]:
197198
if device is not None and not isinstance(device, (str, torch.device)):
198199
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
254255
else:
255256
set_module_tensor_to_device(model, param_name, device, value=param)
256257

258+
if named_buffers is None:
259+
return unexpected_keys
260+
261+
for param_name, param in named_buffers:
262+
if is_quantized and (
263+
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
264+
):
265+
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
266+
else:
267+
if accepts_dtype:
268+
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
269+
else:
270+
set_module_tensor_to_device(model, param_name, device, value=param)
271+
257272
return unexpected_keys
258273

259274

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
913913
" those weights or else make sure your checkpoint file is correct."
914914
)
915915

916+
named_buffers = model.named_buffers()
917+
916918
unexpected_keys = load_model_dict_into_meta(
917919
model,
918920
state_dict,
@@ -921,6 +923,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
921923
model_name_or_path=pretrained_model_name_or_path,
922924
hf_quantizer=hf_quantizer,
923925
keep_in_fp32_modules=keep_in_fp32_modules,
926+
named_buffers=named_buffers,
924927
)
925928

926929
if cls._keys_to_ignore_on_load_unexpected is not None:

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@
2020
import pytest
2121
from huggingface_hub import hf_hub_download
2222

23-
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
23+
from diffusers import (
24+
BitsAndBytesConfig,
25+
DiffusionPipeline,
26+
FluxTransformer2DModel,
27+
SanaTransformer2DModel,
28+
SD3Transformer2DModel,
29+
logging,
30+
)
2431
from diffusers.utils import is_accelerate_version
2532
from diffusers.utils.testing_utils import (
2633
CaptureLogger,
@@ -302,6 +309,33 @@ def test_device_and_dtype_assignment(self):
302309
_ = self.model_fp16.cuda()
303310

304311

312+
class Bnb8bitDeviceTests(Base8bitTests):
    """Device-placement checks for the buffers of an 8-bit quantized Sana transformer."""

    def setUp(self) -> None:
        """Load the Sana transformer with `load_in_8bit=True` after clearing cached memory."""
        gc.collect()
        torch.cuda.empty_cache()

        self.model_8bit = SanaTransformer2DModel.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )

    def tearDown(self):
        """Drop the model reference, then run garbage collection and empty the CUDA cache."""
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_buffers_device_assignment(self):
        """Each named buffer must report the same device type as `torch_device`."""
        for buffer_name, buffer in self.model_8bit.named_buffers():
            # Only the device *type* is compared (e.g. "cuda"), not the device index.
            actual_type = buffer.device.type
            expected_type = torch.device(torch_device).type
            failure_message = f"Expected device {torch_device} for {buffer_name} got {buffer.device}."
            self.assertEqual(actual_type, expected_type, failure_message)
337+
338+
305339
class BnB8bitTrainingTests(Base8bitTests):
306340
def setUp(self):
307341
gc.collect()

0 commit comments

Comments
 (0)