
Commit abd3a91

add tests
1 parent 2753abe commit abd3a91

File tree: 6 files changed, +303 −74 lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 6 deletions

@@ -239,12 +239,8 @@ def load_model_dict_into_meta(
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
         if dtype is not None and torch.is_floating_point(param):
-            if (
-                keep_in_fp32_modules is not None
-                and any(
-                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
-                )
-                and dtype == torch.float16
+            if keep_in_fp32_modules is not None and any(
+                module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
             ):
                 param = param.to(torch.float32)
                 set_module_kwargs["dtype"] = torch.float32
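
Note on the simplified check: the fp32-pinning match is exact on dot-separated components of the parameter name, and after this change it fires for any floating-point target dtype (previously only fp16). A minimal standalone sketch of that matching rule, with hypothetical parameter names:

keep_in_fp32_modules = ["proj_out"]

def should_keep_in_fp32(param_name: str) -> bool:
    # Exact match against dot-separated name components, so "proj_out.weight"
    # matches "proj_out" while "my_proj_out.weight" does not.
    return any(m in param_name.split(".") for m in keep_in_fp32_modules)

assert should_keep_in_fp32("proj_out.weight")
assert not should_keep_in_fp32("my_proj_out.weight")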

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 2 deletions

@@ -1011,9 +1011,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
-            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) or hasattr(
+            hf_quantizer, "use_keep_in_fp32_modules"
         )
+
         if use_keep_in_fp32_modules:
             keep_in_fp32_modules = cls._keep_in_fp32_modules
             if not isinstance(keep_in_fp32_modules, list):
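
The behavioral change here: the fp32 path is now considered whenever the model class defines `_keep_in_fp32_modules` or the quantizer opts in, regardless of `torch_dtype`. A standalone sketch contrasting the old and new predicates (stand-in arguments, not the diffusers API):

import torch

def old_predicate(keep_in_fp32_modules, torch_dtype, hf_quantizer):
    return (keep_in_fp32_modules is not None) and (
        (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
    )

def new_predicate(keep_in_fp32_modules, torch_dtype, hf_quantizer):
    return (keep_in_fp32_modules is not None) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")

# bf16 load with no quantizer: only the new predicate enables fp32 pinning.
print(old_predicate(["proj_out"], torch.bfloat16, None))  # False
print(new_predicate(["proj_out"], torch.bfloat16, None))  # True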

tests/models/test_modeling_common.py

Lines changed: 23 additions & 1 deletion

@@ -37,7 +37,7 @@
 from parameterized import parameterized
 from requests.exceptions import HTTPError

-from diffusers.models import UNet2DConditionModel
+from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -334,6 +334,28 @@ def test_weight_overwrite(self):

         assert model.config.in_channels == 9

+    def test_keep_modules_in_fp32(self):
+        r"""
+        A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32 when the model is
+        loaded in fp16/bf16. Also ensures that inference works.
+        """
+        fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+
+        for torch_dtype in [torch.bfloat16, torch.float16]:
+            SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+            model = SD3Transformer2DModel.from_pretrained(
+                "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch_dtype
+            )
+
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    if name in model._keep_in_fp32_modules:
+                        self.assertTrue(module.weight.dtype == torch.float32)
+                    else:
+                        self.assertTrue(module.weight.dtype == torch_dtype)
+        SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+

 class UNetTesterMixin:
     def test_forward_with_norm_groups(self):
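
One hardening note on the new test: it mutates a class attribute and only restores it after the loop, so a failed assertion would leave the patched value in place for later tests. A sketch of a try/finally variant that guarantees restoration:

from diffusers.models import SD3Transformer2DModel

fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
try:
    SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
    ...  # load the model and assert parameter dtypes, as in the test above
finally:
    # Restored even if an assertion fails mid-loop.
    SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules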

tests/quantization/bnb/test_4bit.py

Lines changed: 100 additions & 7 deletions

@@ -136,7 +136,7 @@ def setUp(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )

     def tearDown(self):
@@ -202,7 +202,7 @@ def test_keep_modules_in_fp32(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )

         for name, module in model.named_modules():
@@ -327,7 +327,7 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             nf4_config = BitsAndBytesConfig(load_in_4bit=True)
             model_4bit = SD3Transformer2DModel.from_pretrained(
-                self.model_name, subfolder="transformer", quantization_config=nf4_config
+                self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
             )
             model_4bit.save_pretrained(tmpdirname)
             del model_4bit
@@ -362,7 +362,7 @@ def setUp(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )

     def test_training(self):
@@ -410,7 +410,7 @@ def setUp(self) -> None:
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(
             self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -472,7 +472,7 @@ def test_moving_to_cpu_throws_warning(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )

         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -502,6 +502,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             subfolder="transformer",
             quantization_config=transformer_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         text_encoder_3_nf4_config = BnbConfig(
             load_in_4bit=True,
@@ -513,6 +514,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             subfolder="text_encoder_3",
             quantization_config=text_encoder_3_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         # CUDA device placement works.
         pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -527,6 +529,94 @@ def test_pipeline_cuda_placement_works_with_nf4(self):

         del pipeline_4bit

+    def test_device_map(self):
+        """
+        Test that the quantized model works properly with device_map="auto".
+        Note that cpu/disk offloading doesn't work with bnb.
+        """
+
+        def get_dummy_tensor_inputs(device=None, seed: int = 0):
+            batch_size = 1
+            num_latent_channels = 4
+            num_image_channels = 3
+            height = width = 4
+            sequence_length = 48
+            embedding_dim = 32
+
+            torch.manual_seed(seed)
+            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+                device, dtype=torch.bfloat16
+            )
+
+            torch.manual_seed(seed)
+            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+            torch.manual_seed(seed)
+            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+            torch.manual_seed(seed)
+            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "txt_ids": text_ids,
+                "img_ids": image_ids,
+                "timestep": timestep,
+            }
+
+        inputs = get_dummy_tensor_inputs(torch_device)
+        expected_slice = np.array(
+            [0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
+        )
+
+        # non-sharded
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+        # sharded
+
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-sharded",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+

 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
@@ -610,7 +700,10 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
         model_0 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=self.quantization_config,
+            device_map=torch_device,
         )
         self.assertTrue("_pre_quantization_dtype" in model_0.config)
         with tempfile.TemporaryDirectory() as tmpdirname:
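
The recurring `device_map=torch_device` additions place the quantized weights directly on the accelerator at load time instead of materializing them on CPU first. A minimal sketch of that loading pattern outside the test harness (model id and config values taken from the tests above; assumes a CUDA device is available):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    device_map="cuda",  # or "auto" to let accelerate place the shards
)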
