Commit 44ed55e

stevhliu and sayakpaul authored
[docs] AOBaseConfig (#12302)
init
Co-authored-by: Sayak Paul <[email protected]>
1 parent 3cebb5f commit 44ed55e

File tree: 1 file changed

docs/source/en/quantization/torchao.md

Lines changed: 66 additions & 39 deletions

# torchao

[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training of PyTorch models. Quantization is supported for any model in any modality, as long as the model can be loaded with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

Make sure PyTorch 2.5+ and torchao are installed with the command below.

```bash
uv pip install -U torch torchao
```

Each quantization dtype is available as a separate instance of an [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more of the available arguments.

Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig), to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].

```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
```
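
The quantized pipeline is used like any other pipeline. The snippet below is an illustrative sketch; the prompt and sampler settings can be changed freely, and the memory figures in the comments are the approximate values measured for int8 weight-only quantization of the Flux transformer.

```py
prompt = "A cat holding a sign that says hello world"
image = pipeline(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
```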

For simple use cases, you can also provide a string identifier in [`TorchAoConfig`] as shown below.

```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig("int8wo")}
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
```
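
[`TorchAoConfig`] also accepts a `modules_to_not_convert` argument, a list of full or partial module names that should be skipped during quantization. As a small sketch, to leave the first single transformer block of [`FluxTransformer2DModel`] unquantized:

```py
quantization_config = TorchAoConfig(
    "int8wo",
    # keep the first single transformer block in the original dtype
    modules_to_not_convert=["single_transformer_blocks.0"],
)
```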

## torch.compile

torchao supports [torch.compile](../optimization/fp16#torchcompile), which can speed up inference with one line of code.

```python
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
```

Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
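
A minimal sketch of that recommendation, assuming torchao's `Float8WeightOnlyConfig` (other FP8 configs from `torchao.quantization` follow the same pattern):

```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

# FP8 weight-only quantization of the transformer (compute capability 8.9+ GPUs)
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Float8WeightOnlyConfig())}
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

# combine FP8 with torch.compile for the best speed/memory trade-off
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
```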

## autoquant

torchao also provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on the chosen input types and shapes. At the moment, this is only supported in Diffusers for individual models.

```py
import torch
from diffusers import DiffusionPipeline
from torchao.quantization import autoquant

# Load the pipeline
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

transformer = autoquant(pipeline.transformer)
```
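
Autoquantization benchmarks candidate strategies on the inputs the model actually sees, so run a generation with the wrapped transformer afterwards. The prompt and settings below are only an illustrative sketch for FLUX.1-schnell.

```py
prompt = "A cat holding a sign that says hello world"
image = pipeline(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")
```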

## Supported quantization types
