Commit 1e699e0

Merge branch 'main' into quant-tip
2 parents: afc04c0 + 7667cfc

19 files changed: +1495 −28 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -157,6 +157,8 @@
       title: Getting Started
     - local: quantization/bitsandbytes
       title: bitsandbytes
+    - local: quantization/torchao
+      title: torchao
     title: Quantization Methods
   - sections:
     - local: optimization/fp16
```

docs/source/en/api/attnprocessor.md

Lines changed: 104 additions & 11 deletions
```diff
@@ -15,40 +15,133 @@ specific language governing permissions and limitations under the License.
 An attention processor is a class for applying different types of attention mechanisms.
 
 ## AttnProcessor
+
 [[autodoc]] models.attention_processor.AttnProcessor
 
-## AttnProcessor2_0
 [[autodoc]] models.attention_processor.AttnProcessor2_0
 
-## AttnAddedKVProcessor
 [[autodoc]] models.attention_processor.AttnAddedKVProcessor
 
-## AttnAddedKVProcessor2_0
 [[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
 
+[[autodoc]] models.attention_processor.AttnProcessorNPU
+
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
+## Allegro
+
+[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0
+
+## AuraFlow
+
+[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0
+
+## CogVideoX
+
+[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0
+
 ## CrossFrameAttnProcessor
+
 [[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
 
-## CustomDiffusionAttnProcessor
+## Custom Diffusion
+
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
 
-## CustomDiffusionAttnProcessor2_0
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
 
-## CustomDiffusionXFormersAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
 
-## FusedAttnProcessor2_0
-[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+## Flux
+
+[[autodoc]] models.attention_processor.FluxAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0
+
+## Hunyuan
+
+[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0
+
+## IdentitySelfAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0
+
+## IP-Adapter
+
+[[autodoc]] models.attention_processor.IPAdapterAttnProcessor
+
+[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
+
+## JointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.JointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0
+
+## LoRA
+
+[[autodoc]] models.attention_processor.LoRAAttnProcessor
+
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
+
+[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
+
+## Lumina-T2X
+
+[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0
+
+## Mochi
+
+[[autodoc]] models.attention_processor.MochiAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0
+
+## Sana
+
+[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0
+
+## Stable Audio
+
+[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0
 
 ## SlicedAttnProcessor
+
 [[autodoc]] models.attention_processor.SlicedAttnProcessor
 
-## SlicedAttnAddedKVProcessor
 [[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
 
 ## XFormersAttnProcessor
+
 [[autodoc]] models.attention_processor.XFormersAttnProcessor
 
-## AttnProcessorNPU
-[[autodoc]] models.attention_processor.AttnProcessorNPU
+[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor
+
+## XLAFlashAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
```

docs/source/en/api/quantization.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 
 [[autodoc]] BitsAndBytesConfig
 
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
 ## DiffusersQuantizer
 
 [[autodoc]] quantizers.base.DiffusersQuantizer
```

docs/source/en/quantization/overview.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be
 
 ## When to use what?
 
-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
```
docs/source/en/quantization/torchao.md (new file)

Lines changed: 92 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.

Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.

```bash
pip install -U torch torchao
```

Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to int8.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```

TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.

```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.

torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
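
As a rough sketch, autoquant can be applied directly to the transformer loaded in the example above; the exact `autoquant` signature and how it composes with `torch.compile` depend on the installed torchao version, so treat this as an assumption to verify against the torchao README rather than a confirmed Diffusers API.

```python
# A sketch only: apply torchao's autoquant to the already-loaded `transformer`.
# The compile-then-autoquant pattern follows the torchao README; check your torchao version.
import torch
import torchao

transformer = torchao.autoquant(torch.compile(transformer, mode="max-autotune"))
```
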
The `TorchAoConfig` class accepts three parameters (combined in the short sketch after this list):
- `quant_type`: A string naming one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantization for the first block of [`FluxTransformer2DModel`], specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments passed to the underlying quantization method, which is selected based on `quant_type`.
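
A minimal sketch of how the three parameters fit together; the `group_size` keyword argument is an assumption here, forwarded to torchao's `int4_weight_only` method, so consult the torchao docs for the kwargs your version accepts.

```python
from diffusers import TorchAoConfig

# int4 weight-only quantization that skips the first single transformer block of Flux;
# `group_size` is an assumed kwarg passed through to torchao's int4_weight_only.
quantization_config = TorchAoConfig(
    "int4wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
    group_size=64,
)
# Pass `quantization_config` to `from_pretrained` exactly as in the Flux example above.
```
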
## Supported quantization types

torchao supports weight-only quantization as well as combined weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may sometimes come at a cost in quality, so it is recommended to test different models thoroughly.
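
To make the difference concrete, here is a minimal sketch contrasting the two modes using the shorthands listed in the table below; which configuration is preferable depends on your model and hardware.

```python
from diffusers import TorchAoConfig

# Weight-only: int8 weights, activations stay in the higher-precision compute dtype.
weight_only_config = TorchAoConfig("int8wo")

# Dynamic activation + weight quantization: int8 weights with activations quantized on the fly.
dynamic_activation_config = TorchAoConfig("int8dq")
```
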
The quantization methods supported are as follows:

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
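
For example, the shorthand and the full torchao function name should select the same method and produce equivalent configurations:

```python
from diffusers import TorchAoConfig

# `int8wo` is the shorthand alias for torchao's `int8_weight_only`.
config_from_shorthand = TorchAoConfig("int8wo")
config_from_full_name = TorchAoConfig("int8_weight_only")
```
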
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options.

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -31,7 +31,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig"],
+    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -569,7 +569,7 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig
+    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
 
     try:
         if not is_onnx_available():
```

src/diffusers/models/attention_processor.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -5423,21 +5423,37 @@ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Ten
 
 
 class LoRAAttnProcessor:
+    r"""
+    Processor for implementing attention with LoRA.
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAAttnProcessor2_0:
+    r"""
+    Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAXFormersAttnProcessor:
+    r"""
+    Processor for implementing attention with LoRA using xFormers.
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAAttnAddedKVProcessor:
+    r"""
+    Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         pass
 
```
src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -25,7 +25,6 @@
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -215,12 +213,12 @@ def load_model_dict_into_meta(
         # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
```

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 6 deletions
```diff
@@ -700,10 +700,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
                 )
+
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
@@ -858,13 +860,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 if device_map is None and not is_sharded:
                     # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
                     # It would error out during the `validate_environment()` call above in the absence of cuda.
-                    is_quant_method_bnb = (
-                        getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-                    )
                     if hf_quantizer is None:
                         param_device = "cpu"
                     # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                    elif is_quant_method_bnb:
+                    else:
                         param_device = torch.device(torch.cuda.current_device())
                     state_dict = load_state_dict(model_file, variant=variant)
                     model._convert_deprecated_attention_blocks(state_dict)
```

src/diffusers/models/unets/unet_1d_blocks.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tens
         if self.upsample:
             hidden_states = self.upsample(hidden_states)
         if self.downsample:
-            self.downsample = self.downsample(hidden_states)
+            hidden_states = self.downsample(hidden_states)
 
         return hidden_states
```
