Commit f4c14c2

update
1 parent f52050a commit f4c14c2

File tree

8 files changed: +221 -28 lines changed

docs/source/en/quantization/quanto.md

Lines changed: 41 additions & 20 deletions
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
 
 # Quanto
 
-[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum.](https://huggingface.co/docs/optimum/en/index)
+[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum.](https://huggingface.co/docs/optimum/en/index)
 It has been designed with versatility and simplicity in mind:
 
 - All features are available in eager mode (works with non-traceable models)
@@ -27,10 +27,10 @@ In order to use the Quanto backend, you will first need to install `optimum-quan
 pip install optimum-quanto accelerate
 ```
 
-Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. The following snippet demonstrates how to apply `float8` quantization with Quanto.
+Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. The following snippet demonstrates how to apply `float8` quantization with Quanto.
 
 ```python
-import torch
+import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
@@ -46,24 +46,57 @@ image = pipe(
 ).images[0]
 image.save("output.png")
 ```
-## Saving Quantized models
 
-Diffusers supports serializing and saving Quanto models using the `save_pretrained` method.
+## Using `from_single_file` with the Quanto Backend
+
 ```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+quantization_config = QuantoConfig(weights="float8")
+transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
+```
+
+## Saving Quantized models
+
+Diffusers supports serializing and saving Quanto models using the `save_pretrained` method.
 
-import torch
+```python
+import torch
 from diffusers import FluxTransformer2DModel, QuantoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
 quantization_config = QuantoConfig(weights="float8")
 transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
 
 # save quantized model to reuse
-transformer.save_pretrained("<your save path>")
+transformer.save_pretrained("<your quantized model save path>")
+
+# you can reload your quantized model with
+model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
+```
+
+## Using `torch.compile` with Quanto
+
+Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights="int8")
+transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
+transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+
+pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
+pipe.to("cuda")
+```
 
 ## Supported Quantization Types
 
-### Weights
+### Weights
 
 - float8
 - int8
@@ -73,15 +106,3 @@ transformer.save_pretrained("<your save path>")
 ### Activations
 - float8
 - int8
-
-
-```
-```
-```
-
-
-```
-
-
-
-
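
Note: the `torch.compile` snippet added above uses `FluxPipeline` and `torch_dtype` without defining them in that code block. A self-contained variant under the same assumptions (the model id comes from the docs; the prompt below is purely illustrative) would look roughly like this:

```python
# Hedged, self-contained sketch of the torch.compile example from the docs above.
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
torch_dtype = torch.bfloat16

# int8 weights are the only dtype the Quanto backend currently compiles
quantization_config = QuantoConfig(weights="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, quantization_config=quantization_config, torch_dtype=torch_dtype
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world").images[0]  # illustrative prompt
image.save("output.png")
```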

src/diffusers/__init__.py

Lines changed: 92 additions & 7 deletions
@@ -2,6 +2,15 @@
 
 from typing import TYPE_CHECKING
 
+from diffusers.quantizers import quantization_config
+from diffusers.utils import dummy_gguf_objects
+from diffusers.utils.import_utils import (
+    is_bitsandbytes_available,
+    is_gguf_available,
+    is_optimum_quanto_version,
+    is_torchao_available,
+)
+
 from .utils import (
     DIFFUSERS_SLOW_IMPORT,
     OptionalDependencyNotAvailable,
@@ -33,12 +42,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": [
-        "BitsAndBytesConfig",
-        "GGUFQuantizationConfig",
-        "QuantoConfig",
-        "TorchAoConfig",
-    ],
+    "quantizers.quantization_config": [],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -73,6 +77,56 @@
 else:
     _import_structure["quantizers.quantization_config"].extend("QuantoConfig")
 """
+
+try:
+    if not is_bitsandbytes_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_bitsandbytes_objects
+
+    _import_structure["utils.dummy_bitsandbytes_objects"] = [
+        name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
+
+try:
+    if not is_gguf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_gguf_objects
+
+    _import_structure["utils.dummy_gguf_objects"] = [
+        name for name in dir(dummy_gguf_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
+
+try:
+    if not is_torchao_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torchao_objects
+
+    _import_structure["utils.dummy_torchao_objects"] = [
+        name for name in dir(dummy_torchao_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
+
+try:
+    if not is_optimum_quanto_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_optimum_quanto_objects
+
+    _import_structure["utils.dummy_optimum_quanto_objects"] = [
+        name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("QuantoConfig")
+
+
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -600,7 +654,38 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, QuantoConfig, TorchAoConfig
+
+    try:
+        if not is_bitsandbytes_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_bitsandbytes_objects import *
+    else:
+        from .quantizers.quantization_config import BitsAndBytesConfig
+
+    try:
+        if not is_gguf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_gguf_objects import *
+    else:
+        from .quantizers.quantization_config import GGUFQuantizationConfig
+
+    try:
+        if not is_torchao_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_torchao_objects import *
+    else:
+        from .quantizers.quantization_config import TorchAoConfig
+
+    try:
+        if not is_optimum_quanto_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_optimum_quanto_objects import *
+    else:
+        from .quantizers.quantization_config import QuantoConfig
 
     try:
         if not is_onnx_available():
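
The `src/diffusers/__init__.py` hunks above repeat one pattern per backend: probe for the optional dependency, register dummy objects when it is missing, and expose the real config class otherwise. A minimal standalone sketch of that pattern (simplified, illustrative names rather than the actual diffusers internals):

```python
# Illustrative sketch of the optional-dependency pattern used in __init__.py above.
import importlib.util

_import_structure = {"quantizers.quantization_config": []}


class OptionalDependencyNotAvailable(Exception):
    """Raised when an optional backend is not installed."""


def is_optimum_quanto_available() -> bool:
    # Simplified probe; diffusers uses its own import_utils helpers.
    return importlib.util.find_spec("optimum") is not None


try:
    if not is_optimum_quanto_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Backend missing: expose dummy objects that raise a helpful error on use.
    _import_structure["utils.dummy_optimum_quanto_objects"] = ["QuantoConfig"]
else:
    # Backend present: expose the real config class lazily.
    _import_structure["quantizers.quantization_config"].append("QuantoConfig")

print(_import_structure)
```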

src/diffusers/models/modeling_utils.py

Lines changed: 0 additions & 1 deletion
@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         model,
         state_dict,
         device=param_device,
-        dtype=torch_dtype,
         model_name_or_path=pretrained_model_name_or_path,
         hf_quantizer=hf_quantizer,
         keep_in_fp32_modules=keep_in_fp32_modules,

src/diffusers/utils/dummy_bitsandbytes_objects.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class BitsAndBytesConfig(metaclass=DummyObject):
+    _backends = ["bitsandbytes"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["bitsandbytes"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["bitsandbytes"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["bitsandbytes"])

src/diffusers/utils/dummy_gguf_objects.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class GGUFQuantizationConfig(metaclass=DummyObject):
+    _backends = ["gguf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["gguf"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["gguf"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["gguf"])

src/diffusers/utils/dummy_optimum_quanto_objects.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class QuantoConfig(metaclass=DummyObject):
+    _backends = ["optimum_quanto"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["optimum_quanto"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["optimum_quanto"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["optimum_quanto"])

src/diffusers/utils/dummy_torchao_objects.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class TorchAoConfig(metaclass=DummyObject):
+    _backends = ["torchao"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torchao"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torchao"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torchao"])
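
All four new dummy modules follow the same `DummyObject` pattern: the config class is always importable, but constructing it without the backend calls `requires_backends`, which raises an error telling the user what to install. A hypothetical session, assuming optimum-quanto is not installed (the exact error text may differ):

```python
# Hypothetical usage of the dummy fallback; the error message wording is illustrative.
from diffusers import QuantoConfig  # succeeds even without optimum-quanto installed

try:
    QuantoConfig(weights="float8")  # instantiation triggers requires_backends
except ImportError as err:
    print(err)  # e.g. "QuantoConfig requires the optimum_quanto library but it was not found ..."
```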

tests/quantization/quanto/test_quanto.py

Lines changed: 20 additions & 0 deletions
@@ -9,6 +9,7 @@
 from diffusers.utils import is_optimum_quanto_available
 from diffusers.utils.testing_utils import (
     nightly,
+    numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     torch_device,
@@ -142,6 +143,25 @@ class FluxTransformerInt8(FluxTransformerQuantoMixin, unittest.TestCase):
     def get_dummy_init_kwargs(self):
         return {"weights": "int8"}
 
+    def test_torch_compile(self):
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        inputs = self.get_dummy_inputs()
+
+        model.to(torch_device)
+        with torch.no_grad():
+            model_output = model(**inputs).sample
+        model.to("cpu")
+
+        compiled_model.to(torch_device)
+        with torch.no_grad():
+            compiled_model_output = compiled_model(**inputs).sample
+
+        max_diff = numpy_cosine_similarity_distance(
+            model_output.cpu().flatten(), compiled_model_output.cpu().flatten()
+        )
+        assert max_diff < 1e-4
+
 
 class FluxTransformerInt4(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 6
