Commit fbcdf8b

Merge branch 'main' into cache-non-lora-outputs
2 parents ead2e04 + 64a5187 commit fbcdf8b

4 files changed

+295
-89
lines changed

docs/source/en/quantization/torchao.md

Lines changed: 66 additions & 39 deletions
@@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. -->
1111

1212
# torchao
1313

14-
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
14+
[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for PyTorch inference and training. It works with any model in any modality, as long as the model supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
1515

16-
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
16+
Install PyTorch 2.5+ and torchao with the command below.
1717

1818
```bash
19-
pip install -U torch torchao
19+
uv pip install -U torch torchao
2020
```
2121

22+
Each quantization dtype has its own [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) subclass. These classes expose more arguments than the string identifiers, which allows more flexible configuration.
2223

23-
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
24+
Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig), to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
2425

25-
The example below only quantizes the weights to int8.
26-
27-
```python
26+
```py
2827
import torch
29-
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
30-
31-
model_id = "black-forest-labs/FLUX.1-dev"
32-
dtype = torch.bfloat16
28+
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
29+
from torchao.quantization import Int8WeightOnlyConfig
3330

34-
quantization_config = TorchAoConfig("int8wo")
35-
transformer = AutoModel.from_pretrained(
36-
model_id,
37-
subfolder="transformer",
38-
quantization_config=quantization_config,
39-
torch_dtype=dtype,
31+
pipeline_quant_config = PipelineQuantizationConfig(
32+
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}
4033
)
41-
pipe = FluxPipeline.from_pretrained(
42-
model_id,
43-
transformer=transformer,
44-
torch_dtype=dtype,
34+
pipeline = DiffusionPipeline.from_pretrained(
35+
"black-forest-labs/FLUX.1-dev",
36+
quantization_config=pipeline_quant_config,
37+
torch_dtype=torch.bfloat16,
38+
device_map="cuda"
4539
)
46-
pipe.to("cuda")
40+
```
4741

48-
# Without quantization: ~31.447 GB
49-
# With quantization: ~20.40 GB
50-
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
42+
For simple use cases, you can also pass a string identifier to [`TorchAoConfig`], as shown below.
5143

52-
prompt = "A cat holding a sign that says hello world"
53-
image = pipe(
54-
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
55-
).images[0]
56-
image.save("output.png")
44+
```py
45+
import torch
46+
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
47+
48+
pipeline_quant_config = PipelineQuantizationConfig(
49+
quant_mapping={"transformer": TorchAoConfig("int8wo")}
50+
)
51+
pipeline = DiffusionPipeline.from_pretrained(
52+
"black-forest-labs/FLUX.1-dev",
53+
quantization_config=pipeline_quant_config,
54+
torch_dtype=torch.bfloat16,
55+
device_map="cuda"
56+
)
5757
```
5858

59-
TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
59+
## torch.compile
60+
61+
torchao supports [torch.compile](../optimization/fp16#torchcompile), which can speed up inference with one line of code.
6062

6163
```python
62-
# In the above code, add the following after initializing the transformer
63-
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
64+
import torch
65+
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
66+
from torchao.quantization import Int4WeightOnlyConfig
67+
68+
pipeline_quant_config = PipelineQuantizationConfig(
69+
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
70+
)
71+
pipeline = DiffusionPipeline.from_pretrained(
72+
"black-forest-labs/FLUX.1-dev",
73+
quantization_config=pipeline_quant_config,
74+
torch_dtype=torch.bfloat16,
75+
device_map="cuda"
76+
)
77+
78+
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
6479
```
6580

66-
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
81+
Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
6782

6883
> [!TIP]
6984
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
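
One way to tell whether the FP8 schemes apply is to compare the GPU's compute capability against `(8, 9)`. The sketch below is a minimal illustration and not part of Diffusers: `supports_fp8` and `current_capability` are hypothetical helpers, and `torch.cuda.get_device_capability()` is only called when PyTorch and a CUDA device are present.

```python
from typing import Optional, Tuple

# Minimum CUDA compute capability for the FP8 schemes, per the tip above.
FP8_MIN_CAPABILITY = (8, 9)


def supports_fp8(capability: Tuple[int, int]) -> bool:
    """Hypothetical helper: True if (major, minor) is at least 8.9."""
    return tuple(capability) >= FP8_MIN_CAPABILITY


def current_capability() -> Optional[Tuple[int, int]]:
    """Return the active CUDA device's capability, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


capability = current_capability()
if capability is None:
    print("No CUDA device detected; FP8 support could not be checked.")
else:
    print(f"Compute capability {capability}: FP8 supported = {supports_fp8(capability)}")
```

The comparison relies on Python's tuple ordering, so `(9, 0)` passes and `(8, 6)` does not.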
7085
71-
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
86+
## autoquant
87+
88+
torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
89+
90+
```py
91+
import torch
92+
from diffusers import DiffusionPipeline
93+
from torchao.quantization import autoquant
94+
95+
# Load the pipeline
96+
pipeline = DiffusionPipeline.from_pretrained(
97+
"black-forest-labs/FLUX.1-schnell",
98+
torch_dtype=torch.bfloat16,
99+
device_map="cuda"
100+
)
72101

73-
The `TorchAoConfig` class accepts three parameters:
74-
- `quant_type`: A string value mentioning one of the quantization types below.
75-
- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
76-
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
102+
transformer = autoquant(pipeline.transformer)
103+
```
77104

78105
## Supported quantization types
79106

src/diffusers/quantizers/quantization_config.py

Lines changed: 136 additions & 29 deletions
@@ -21,19 +21,20 @@
2121
"""
2222

2323
import copy
24+
import dataclasses
2425
import importlib.metadata
2526
import inspect
2627
import json
2728
import os
2829
import warnings
29-
from dataclasses import dataclass
30+
from dataclasses import dataclass, is_dataclass
3031
from enum import Enum
3132
from functools import partial
3233
from typing import Any, Callable, Dict, List, Optional, Union
3334

3435
from packaging import version
3536

36-
from ..utils import is_torch_available, is_torchao_available, logging
37+
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
3738

3839

3940
if is_torch_available():
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
443444
"""This is a config class for torchao quantization/sparsity techniques.
444445
445446
Args:
446-
quant_type (`str`):
447+
quant_type (Union[`str`, AOBaseConfig]):
447448
The type of quantization we want to use, currently supporting:
448449
- **Integer quantization:**
449450
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
465466
- **Unsigned Integer quantization:**
466467
- Full function names: `uintx_weight_only`
467468
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
469+
- An AOBaseConfig instance: for more advanced configuration options.
468470
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
469471
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
470472
modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
478480
```python
479481
from diffusers import FluxTransformer2DModel, TorchAoConfig
480482
483+
# AOBaseConfig-based configuration
484+
from torchao.quantization import Int8WeightOnlyConfig
485+
486+
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
487+
488+
# String-based config
481489
quantization_config = TorchAoConfig("int8wo")
482490
transformer = FluxTransformer2DModel.from_pretrained(
483491
"black-forest-labs/Flux.1-Dev",
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
490498

491499
def __init__(
492500
self,
493-
quant_type: str,
501+
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
494502
modules_to_not_convert: Optional[List[str]] = None,
495503
**kwargs,
496504
) -> None:
@@ -504,34 +512,103 @@ def __init__(
504512
else:
505513
self.quant_type_kwargs = kwargs
506514

507-
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
508-
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
509-
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
510-
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
515+
self.post_init()
516+
517+
def post_init(self):
518+
if not isinstance(self.quant_type, str):
519+
if is_torchao_version("<=", "0.9.0"):
511520
raise ValueError(
512-
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
513-
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
521+
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
522+
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
514523
)
515524

516-
raise ValueError(
517-
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
518-
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
519-
)
525+
from torchao.quantization.quant_api import AOBaseConfig
520526

521-
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
522-
signature = inspect.signature(method)
523-
all_kwargs = {
524-
param.name
525-
for param in signature.parameters.values()
526-
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
527-
}
528-
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
527+
if not isinstance(self.quant_type, AOBaseConfig):
528+
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
529529

530-
if len(unsupported_kwargs) > 0:
531-
raise ValueError(
532-
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
533-
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
534-
)
530+
elif isinstance(self.quant_type, str):
531+
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
532+
533+
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
534+
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
535+
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
536+
raise ValueError(
537+
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability < 8.9. You "
538+
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
539+
)
540+
541+
raise ValueError(
542+
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
543+
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
544+
)
545+
546+
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
547+
signature = inspect.signature(method)
548+
all_kwargs = {
549+
param.name
550+
for param in signature.parameters.values()
551+
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
552+
}
553+
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
554+
555+
if len(unsupported_kwargs) > 0:
556+
raise ValueError(
557+
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
558+
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
559+
)
560+
561+
def to_dict(self):
562+
"""Convert configuration to a dictionary."""
563+
d = super().to_dict()
564+
565+
if isinstance(self.quant_type, str):
566+
# Handle layout serialization if present
567+
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
568+
if is_dataclass(d["quant_type_kwargs"]["layout"]):
569+
d["quant_type_kwargs"]["layout"] = [
570+
d["quant_type_kwargs"]["layout"].__class__.__name__,
571+
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
572+
]
573+
if isinstance(d["quant_type_kwargs"]["layout"], list):
574+
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
575+
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
576+
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
577+
else:
578+
raise ValueError("layout must be a list")
579+
else:
580+
# Handle AOBaseConfig serialization
581+
from torchao.core.config import config_to_dict
582+
583+
# For now we assume there is one config per transformer; in the future
584+
# we may want to support a config per fully qualified name (fqn).
585+
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
586+
587+
return d
588+
589+
@classmethod
590+
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
591+
"""Create configuration from a dictionary."""
592+
if not is_torchao_version(">", "0.9.0"):
593+
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
594+
config_dict = config_dict.copy()
595+
quant_type = config_dict.pop("quant_type")
596+
597+
if isinstance(quant_type, str):
598+
return cls(quant_type=quant_type, **config_dict)
599+
# Check if we only have one key which is "default"
600+
# In the future we may update this
601+
assert len(quant_type) == 1 and "default" in quant_type, (
602+
"Expected only one key 'default' in quant_type dictionary"
603+
)
604+
quant_type = quant_type["default"]
605+
606+
# Deserialize quant_type if needed
607+
from torchao.core.config import config_from_dict
608+
609+
quant_type = config_from_dict(quant_type)
610+
611+
return cls(quant_type=quant_type, **config_dict)
535612

536613
@classmethod
537614
def _get_torchao_quant_type_to_method(cls):
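
The serialized shape that `to_dict` produces and `from_dict` consumes in the hunk above can be illustrated without torchao. In this sketch, `DemoLayout`, `serialize_layout`, `wrap_quant_type`, and `unwrap_quant_type` are hypothetical stand-ins, the `inner_k_tiles` field is only an example, and a plain dict takes the place of the output of `config_to_dict`.

```python
import dataclasses
from dataclasses import dataclass, is_dataclass


@dataclass
class DemoLayout:
    """Hypothetical stand-in for a torchao layout dataclass."""

    inner_k_tiles: int = 8


def serialize_layout(layout):
    """Mirror the string quant_type branch: store [class name, kwargs]."""
    if is_dataclass(layout):
        layout = [layout.__class__.__name__, dataclasses.asdict(layout)]
    if not isinstance(layout, list):
        raise ValueError("layout must be a list")
    name, kwargs = layout  # exactly two entries are expected
    assert isinstance(name, str) and isinstance(kwargs, dict)
    return layout


def wrap_quant_type(config_as_dict):
    """Mirror the AOBaseConfig branch: one config under the "default" key."""
    return {"default": config_as_dict}


def unwrap_quant_type(quant_type):
    """Mirror from_dict: strings pass through, dicts must hold only "default"."""
    if isinstance(quant_type, str):
        return quant_type
    if len(quant_type) != 1 or "default" not in quant_type:
        raise ValueError("Expected only one key 'default' in quant_type dictionary")
    return quant_type["default"]


print(serialize_layout(DemoLayout()))
print(unwrap_quant_type(wrap_quant_type({"group_size": 128})))
```

Keeping the single config under a `"default"` key leaves room for the per-module mapping that the code comment mentions as a possible future extension.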
@@ -681,8 +758,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
681758
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
682759

683760
def get_apply_tensor_subclass(self):
684-
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
685-
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
761+
"""Create the appropriate quantization method based on configuration."""
762+
if not isinstance(self.quant_type, str):
763+
return self.quant_type
764+
else:
765+
methods = self._get_torchao_quant_type_to_method()
766+
quant_type_kwargs = self.quant_type_kwargs.copy()
767+
if (
768+
not torch.cuda.is_available()
769+
and is_torchao_available()
770+
and self.quant_type == "int4_weight_only"
771+
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
772+
and quant_type_kwargs.get("layout", None) is None
773+
):
774+
if torch.xpu.is_available():
775+
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
776+
"0.11.0"
777+
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
778+
from torchao.dtypes import Int4XPULayout
779+
from torchao.quantization.quant_primitives import ZeroPointDomain
780+
781+
quant_type_kwargs["layout"] = Int4XPULayout()
782+
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
783+
else:
784+
raise ValueError(
785+
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade both packages, or run on CPU with the CPU build of PyTorch."
786+
)
787+
else:
788+
from torchao.dtypes import Int4CPULayout
789+
790+
quant_type_kwargs["layout"] = Int4CPULayout()
791+
792+
return methods[self.quant_type](**quant_type_kwargs)
686793

687794
def __repr__(self):
688795
r"""
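
The layout fallback added to `get_apply_tensor_subclass` can be summarized as a small decision function. This is a simplified sketch under stated assumptions, not the library code: `pick_int4_layout` is a hypothetical name, versions are passed as integer tuples instead of being read with `importlib.metadata`, and layouts are represented by their class names as strings.

```python
from typing import Optional, Tuple

Version = Tuple[int, ...]


def pick_int4_layout(
    quant_type: str,
    cuda_available: bool,
    xpu_available: bool,
    torchao_version: Version,
    torch_version: Version,
    layout: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical summary of the int4 layout fallback; returns a layout name."""
    needs_fallback = (
        not cuda_available
        and quant_type == "int4_weight_only"
        and torchao_version >= (0, 8, 0)
        and layout is None
    )
    if not needs_fallback:
        return layout  # keep whatever the caller passed (possibly None)
    if not xpu_available:
        return "Int4CPULayout"
    if torchao_version >= (0, 11, 0) and torch_version > (2, 7, 9):
        return "Int4XPULayout"  # the diff also sets zero_point_domain to INT
    raise ValueError("XPU support requires torchao >= 0.11.0 and torch >= 2.8.0")


print(pick_int4_layout("int4_weight_only", False, True, (0, 11, 0), (2, 8, 0)))
print(pick_int4_layout("int4_weight_only", False, False, (0, 11, 0), (2, 8, 0)))
```

As in the diff, the fallback only applies to the full name `int4_weight_only` and only when the caller has not supplied a `layout`.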
