diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 5c7578dcbb4e..18cc109e0785 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. --> # torchao -[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more. +[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. -Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed. +Make sure Pytorch 2.5+ and torchao are installed with the command below. ```bash -pip install -U torch torchao +uv pip install -U torch torchao ``` +Each quantization dtype is available as a separate instance of a [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more available arguments. -Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. +Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig) to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`]. -The example below only quantizes the weights to int8. - -```python +```py import torch -from diffusers import FluxPipeline, AutoModel, TorchAoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -dtype = torch.bfloat16 +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig +from torchao.quantization import Int8WeightOnlyConfig -quantization_config = TorchAoConfig("int8wo") -transformer = AutoModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=dtype, +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))} ) -pipe = FluxPipeline.from_pretrained( - model_id, - transformer=transformer, - torch_dtype=dtype, +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantzation_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" ) -pipe.to("cuda") +``` -# Without quantization: ~31.447 GB -# With quantization: ~20.40 GB -print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") +For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below. -prompt = "A cat holding a sign that says hello world" -image = pipe( - prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 -).images[0] -image.save("output.png") +```py +import torch +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig("int8wo")} +) +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantzation_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" +) ``` -TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code. +## torch.compile + +torchao supports [torch.compile](../optimization/fp16#torchcompile) which can speed up inference with one line of code. ```python -# In the above code, add the following after initializing the transformer -transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) +import torch +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig +from torchao.quantization import Int4WeightOnlyConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))} +) +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantzation_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + device_map="cuda" +) + +pipeline.transformer.compile(transformer, mode="max-autotune", fullgraph=True) ``` -For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware. +Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). > [!TIP] > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible. -torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future. +## autoquant + +torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment. + +```py +import torch +from diffusers import DiffusionPipeline +from torchao.quantization import autoquant + +# Load the pipeline +pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, + device_map="cuda" +) -The `TorchAoConfig` class accepts three parameters: -- `quant_type`: A string value mentioning one of the quantization types below. -- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`. -- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`. +transformer = autoquant(pipeline.transformer) +``` ## Supported quantization types