
Commit 0ed31bc

update
1 parent afd5d7d commit 0ed31bc

File tree

2 files changed: +126 -4 lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1010,14 +1010,14 @@ def to(self, *args, **kwargs):
                     dtype_present_in_args = True
                     break
 
-        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
-        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+        if getattr(self, "is_quantized", False):
             if dtype_present_in_args:
                 raise ValueError(
-                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
-                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                    "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+                    "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
                 )
 
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
                 raise ValueError(
                     "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
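With this change, `ModelMixin.to()` rejects a `dtype` for any quantized model (`is_quantized`), not only for bitsandbytes ones, while moving the model to a device still works. A minimal sketch of the resulting behavior, reusing the GGUF checkpoint and `GGUFQuantizationConfig` from the test file below:

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# GGUF checkpoint also used by the tests below.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
model = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
)

# Moving a quantized model to a device is still allowed.
model.to("cuda")

# Casting to a new dtype now raises for any quantized model, not just bitsandbytes;
# the dtype of unquantized layers is set via `torch_dtype` at load time instead.
try:
    model.to(torch.float16)
except ValueError as err:
    print(err)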
Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@

import gc
import unittest

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig
from diffusers.quantizers.gguf.utils import GGUFParameter
from diffusers.utils.testing_utils import (
    nightly,
    require_big_gpu_with_torch_cuda,
    torch_device,
)


@nightly
@require_big_gpu_with_torch_cuda
class GGUFSingleFileTests(unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
    torch_dtype = torch.bfloat16

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }

    def test_gguf_parameters(self):
        quant_storage_type = torch.uint8
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)

        for param_name, param in model.named_parameters():
            if isinstance(param, GGUFParameter):
                assert hasattr(param, "quant_type")
                assert param.dtype == quant_storage_type

    def test_gguf_linear_layers(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
                assert module.weight.dtype == torch.uint8

    def test_gguf_memory(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

        model = FluxTransformer2DModel.from_single_file(
            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
        )
        model.to("cuda")
        inputs = self.get_dummy_inputs()

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        with torch.no_grad():
            model(**inputs)
        max_memory = torch.cuda.max_memory_allocated()
        assert (max_memory / 1024**3) < 5

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        FluxTransformer2DModel._keep_in_fp32_modules = ["proj_out"]

        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32

    def test_dtype_assignment(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        model = FluxTransformer2DModel.from_single_file(self.ckpt_path, quantization_config=quantization_config)

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            model.to(device="cuda:0", dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to("cuda")
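Beyond these tests, a GGUF-quantized transformer loaded this way would typically be dropped into a full pipeline. A rough usage sketch; the `FluxPipeline` wiring and the `black-forest-labs/FLUX.1-dev` base repository are assumptions not part of this commit, only the `from_single_file` loading path is exercised by the tests above:

import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

# Load the transformer from the GGUF checkpoint, as in the tests above.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Assumed wiring: pass the pre-quantized transformer into the pipeline
# (the base repo id here is illustrative).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe("a photo of a cat", num_inference_steps=25).images[0]
image.save("cat.png")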
