
Commit 6cf9a78 ("update")
1 parent c29684f

File tree: 6 files changed, +36 -17 lines

docs/source/en/api/quantization.md
Lines changed: 5 additions & 0 deletions

```diff
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 ## GGUFQuantizationConfig
 
 [[autodoc]] GGUFQuantizationConfig
+
+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
 ## TorchAoConfig
 
 [[autodoc]] TorchAoConfig
```

docs/source/en/quantization/quanto.md
Lines changed: 5 additions & 5 deletions

````diff
@@ -33,7 +33,7 @@ import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
-quantization_config = QuantoConfig(weights="float8")
+quantization_config = QuantoConfig(weights_dtype="float8")
 transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
 
 pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
@@ -55,7 +55,7 @@ import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
-quantization_config = QuantoConfig(weights="float8", modules_to_not_convert=["proj_out"])
+quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
 transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
 ```
 
@@ -66,7 +66,7 @@ import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
-quantization_config = QuantoConfig(weights="float8")
+quantization_config = QuantoConfig(weights_dtype="float8")
 transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
 ```
 
@@ -79,7 +79,7 @@ import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
-quantization_config = QuantoConfig(weights="float8")
+quantization_config = QuantoConfig(weights_dtype="float8")
 transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
 
 # save quantized model to reuse
@@ -100,7 +100,7 @@ import torch
 from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
-quantization_config = QuantoConfig(weights="int8")
+quantization_config = QuantoConfig(weights_dtype="int8")
 transformer = FluxTransformer2DModel.from_pretrained(
     model_id,
     subfolder="transformer",
````
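The doc changes above all rename the `weights` keyword of `QuantoConfig` to `weights_dtype`. A minimal end-to-end sketch assembled from the updated snippets in this diff (same model id and dtypes as the docs):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# `weights_dtype` replaces the old `weights` keyword.
quantization_config = QuantoConfig(weights_dtype="float8")

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
```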

setup.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -241,7 +241,7 @@ def run(self):
 
 extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
 extras["gguf"] = deps_list("gguf", "accelerate")
-extras["quanto"] = deps_list("quanto", "accelerate")
+extras["quanto"] = deps_list("optimum_quanto", "accelerate")
 extras["torchao"] = deps_list("torchao", "accelerate")
 
 if os.name == "nt":  # windows
```
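With this change, the `quanto` extra resolves to the `optimum_quanto` dependency rather than the old standalone `quanto` package. A quick sanity check for which backend an environment actually has (illustrative only, not part of this commit):

```python
import importlib.util

def has_optimum_quanto() -> bool:
    # find_spec imports the parent package, so guard the `optimum` namespace first.
    if importlib.util.find_spec("optimum") is None:
        return False
    return importlib.util.find_spec("optimum.quanto") is not None

print("optimum-quanto available:", has_optimum_quanto())
```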

src/diffusers/models/model_loading_utils.py
Lines changed: 6 additions & 1 deletion

```diff
@@ -259,6 +259,9 @@ def load_model_dict_into_meta(
             ):
                 param = param.to(torch.float32)
                 set_module_kwargs["dtype"] = torch.float32
+            # For quantizers that saved weights in torch.float8_e4m3fn, keep the storage dtype
+            elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+                pass
             else:
                 param = param.to(dtype)
                 set_module_kwargs["dtype"] = dtype
@@ -306,7 +309,9 @@ def load_model_dict_into_meta(
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
         else:
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
 
```
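The new `elif` branch leaves quantizer-saved `float8_e4m3fn` weights in their storage dtype instead of casting them to the target `dtype`; the quantizer consumes them later via `create_quantized_param`. A standalone sketch of that dtype dispatch, with hypothetical names and simplified arguments for illustration:

```python
import torch

def resolve_param_dtype(param: torch.Tensor, target_dtype: torch.dtype, has_quantizer: bool) -> torch.Tensor:
    # Hypothetical distillation of the branch added in this commit.
    float8 = getattr(torch, "float8_e4m3fn", None)  # older torch builds lack float8
    if has_quantizer and float8 is not None and param.dtype == float8:
        # Pass quantizer-saved float8 weights through untouched.
        return param
    return param.to(target_dtype)

if hasattr(torch, "float8_e4m3fn"):
    w = torch.randn(4, 4).to(torch.float8_e4m3fn)
    assert resolve_param_dtype(w, torch.bfloat16, has_quantizer=True).dtype == torch.float8_e4m3fn
```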

src/diffusers/quantizers/quanto/quanto_quantizer.py
Lines changed: 8 additions & 0 deletions

```diff
@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
+from diffusers.utils.import_utils import is_optimum_quanto_version
+
 from ...utils import (
     get_module_from_name,
     is_accelerate_available,
@@ -44,6 +46,12 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
             )
+        if not is_optimum_quanto_version(">=", "0.2.6"):
+            raise ImportError(
+                "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
+                "Please upgrade your installation with `pip install --upgrade optimum-quanto`."
+            )
+
         if not is_accelerate_available():
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
```

tests/quantization/quanto/test_quanto.py
Lines changed: 11 additions & 10 deletions

```diff
@@ -1,19 +1,18 @@
-import tempfile
 import gc
+import tempfile
 import unittest
 
-import torch
-
-from diffusers import QuantoConfig, FluxTransformer2DModel, FluxPipeline
-from diffusers.utils import is_torch_available, is_optimum_quanto_available
+from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
+from diffusers.models.attention_processor import Attention
+from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     torch_device,
 )
-from diffusers.models.attention_processor import Attention
+
 
 if is_optimum_quanto_available():
     from optimum.quanto import QLinear
@@ -192,7 +191,11 @@ def test_torch_compile(self):
         with torch.no_grad():
             compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
 
-        assert torch.allclose(model_output, compiled_model_output, rtol=1e-2, atol=1e-3)
+        model_output = model_output.detach().float().cpu().numpy()
+        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
+
+        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
+        assert max_diff < 1e-3
 
 
 class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
@@ -275,7 +278,7 @@ def test_model_cpu_offload(self):
             "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
         )
         pipe.enable_model_cpu_offload(device=torch_device)
-        images = pipe("a cat holding a sign that says hello", num_inference_steps=2)
+        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
 
     def test_training(self):
         quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
@@ -311,7 +314,6 @@ def test_training(self):
 
 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.3
-    _test_torch_compile = True
 
     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}
@@ -341,7 +343,6 @@ def get_dummy_init_kwargs(self):
 
 class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.55
-    _test_torch_compile = True
 
     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "int4"}
```
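The compile test now compares outputs with `numpy_cosine_similarity_distance` instead of `torch.allclose`, which tolerates the elementwise jitter compiled kernels introduce while still catching directional drift. A sketch consistent with how such a helper behaves (assumed, not copied from `diffusers.utils.testing_utils`):

```python
import numpy as np

def numpy_cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened vectors; 0.0 means identical direction.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

x = np.random.rand(16).astype(np.float32)
assert numpy_cosine_similarity_distance(x, x) < 1e-6
```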
