Commit c4b6e24

update
1 parent 79901e4 commit c4b6e24

6 files changed (+181 -48 lines)

docs/source/en/quantization/quanto.md

Lines changed: 15 additions & 4 deletions
@@ -91,19 +91,30 @@ model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>

 ## Using `torch.compile` with Quanto

-Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.
+Currently the Quanto backend supports `torch.compile` for the following quantization types:
+
+- `int8` weights

 ```python
 import torch
-from diffusers import FluxTransformer2DModel, QuantoConfig
+from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

 model_id = "black-forest-labs/FLUX.1-dev"
 quantization_config = QuantoConfig(weights="int8")
-transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
+transformer = FluxTransformer2DModel.from_pretrained(
+    model_id,
+    subfolder="transformer",
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16,
+)
 transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

-pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
+pipe = FluxPipeline.from_pretrained(
+    model_id, transformer=transformer, torch_dtype=torch_dtype
+)
 pipe.to("cuda")
+images = pipe("A cat holding a sign that says hello").images[0]
+images.save("flux-quanto.png")
 ```

 ## Supported Quantization Types
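Note that the snippet in the hunk above still passes `weights="int8"` and references `torch_dtype` without defining it, while this same commit renames the `QuantoConfig` arguments to `weights_dtype`/`activations_dtype` (see `quantization_config.py` below). A self-contained version consistent with the renamed keyword might look like the sketch below; treat the keyword name and the explicit `torch_dtype` as assumptions rather than the documented example.

```python
# Minimal sketch, assuming the renamed QuantoConfig keyword (weights_dtype) from
# this commit and an explicitly defined torch_dtype; not the exact docs snippet.
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
torch_dtype = torch.bfloat16  # the docs snippet uses this name without defining it

quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
image.save("flux-quanto.png")
```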

setup.py

Lines changed: 5 additions & 0 deletions
@@ -236,6 +236,11 @@ def run(self):
 )
 extras["torch"] = deps_list("torch", "accelerate")

+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["quanto"] = deps_list("quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
 else:
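With these extras declared, each quantization backend can presumably be installed together with `accelerate` in one step, e.g. `pip install diffusers[quanto]` (and likewise `[bitsandbytes]`, `[gguf]`, `[torchao]`).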

src/diffusers/quantizers/auto.py

Lines changed: 2 additions & 2 deletions
@@ -37,16 +37,16 @@
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
     "gguf": GGUFQuantizer,
-    "torchao": TorchAoHfQuantizer,
     "quanto": QuantoQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }

 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gguf": GGUFQuantizationConfig,
-    "torchao": TorchAoConfig,
     "quanto": QuantoConfig,
+    "torchao": TorchAoConfig,
 }
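The reorder is purely alphabetical; both mappings keep the same entries. They are what the auto-quantization machinery consults to turn a `quant_method` string into a config class and quantizer class. A rough sketch of that lookup (the name `AUTO_QUANTIZER_MAPPING` for the first dictionary is assumed, since it is not visible in the hunk, and the real dispatch lives in `DiffusersAutoQuantizer`):

```python
# Illustrative only: how the two mappings above are typically used together.
method = "quanto"
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[method]  # -> QuantoConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING[method]  # -> QuantoQuantizer (dict name assumed)

quantizer = quantizer_cls(config_cls(weights_dtype="int8"))
```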

src/diffusers/quantizers/quantization_config.py

Lines changed: 8 additions & 8 deletions
@@ -695,14 +695,14 @@ class QuantoConfig(QuantizationConfigMixin):

     def __init__(
         self,
-        weights="int8",
-        activations=None,
+        weights_dtype="int8",
+        activations_dtype=None,
         modules_to_not_convert: Optional[List] = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.QUANTO
-        self.weights = weights
-        self.activations = activations
+        self.weights_dtype = weights_dtype
+        self.activations_dtype = activations_dtype
         self.modules_to_not_convert = modules_to_not_convert

         self.post_init()
@@ -713,8 +713,8 @@ def post_init(self):
         """
         accepted_weights = ["float8", "int8", "int4", "int2"]
         accepted_activations = [None, "int8", "float8"]
-        if self.weights not in accepted_weights:
-            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
+        if self.weights_dtype not in accepted_weights:
+            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")

-        if self.activations not in accepted_activations:
-            raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")
+        if self.activations_dtype not in accepted_activations:
+            raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations_dtype}")

src/diffusers/quantizers/quanto/utils.py

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ def _replace_layers(model, quantization_config, modules_to_not_convert):
                 out_features=module.out_features,
                 bias=module.bias is not None,
                 dtype=module.weight.dtype,
-                weights=_get_weight_type(quantization_config.weights),
-                activations=_get_activation_type(quantization_config.activations),
+                weights=_get_weight_type(quantization_config.weights_dtype),
+                activations=_get_activation_type(quantization_config.activations_dtype),
             )
             model._modules[name] = qlinear
             model._modules[name].source_cls = type(module)
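`_get_weight_type` and `_get_activation_type` are not shown in this hunk; a plausible sketch of helpers with that contract, assuming they map the config strings to optimum-quanto qtypes (the dictionary names below are hypothetical):

```python
# Hypothetical sketch of the helpers consumed above; the real implementations
# live elsewhere in src/diffusers/quantizers/quanto/utils.py and may differ.
from optimum.quanto import qfloat8, qint2, qint4, qint8

_WEIGHT_TYPES = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
_ACTIVATION_TYPES = {"float8": qfloat8, "int8": qint8}


def _get_weight_type(dtype: str):
    # Maps QuantoConfig.weights_dtype strings ("int8", "float8", ...) to qtypes.
    return _WEIGHT_TYPES[dtype]


def _get_activation_type(dtype):
    # activations_dtype defaults to None, i.e. no activation quantization.
    return _ACTIVATION_TYPES[dtype] if dtype is not None else None
```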
