
Commit 8844a2b

Support attention quantization for diffusers >= 0.35.0 (NVIDIA#608)
## What does this PR do?

**Type of change:** new feature

**Overview:** The attention mechanism changed in diffusers 0.35: many model attention classes are now subclasses of a new mixin, AttentionModuleMixin, which is not a subclass of Attention. To fix this, patch the mixin class to force the native attention implementation so that the existing function monkey patch still works.

## Testing

Manual quantization of Wan and Flux.

---------

Signed-off-by: Shengliang Xu <[email protected]>
1 parent 1524251 commit 8844a2b
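
The core of the fix, as described in the overview, is to pin diffusers' attention dispatch to the native backend so that the monkey patch of `torch.nn.functional.scaled_dot_product_attention` is still hit. Below is a minimal sketch of that idea, assuming the diffusers >= 0.35 imports used in the diff further down; `force_native_attention` is a hypothetical helper for illustration, not code from this PR.

```python
import torch
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend


def force_native_attention(module: torch.nn.Module) -> None:
    """Hypothetical helper: always run `module.forward` under the NATIVE attention backend.

    Pinning the backend keeps the module on torch.nn.functional.scaled_dot_product_attention,
    the function the existing quantization monkey patch replaces, instead of a non-native
    backend that would bypass the patch.
    """
    original_forward = module.forward

    def wrapped_forward(*args, **kwargs):
        with attention_backend(AttentionBackendName.NATIVE):
            return original_forward(*args, **kwargs)

    module.forward = wrapped_forward
```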

File tree

3 files changed: +45 -16 lines changed


examples/diffusers/quantization/quantize.py

Lines changed: 5 additions & 5 deletions
@@ -760,8 +760,6 @@ def quantize_model(
         self.logger.info("Disabling specific quantizers...")
         mtq.disable_quantizer(backbone, model_filter_func)

-        mtq.print_quant_summary(backbone)
-
         self.logger.info("Quantization completed successfully")


@@ -816,7 +814,6 @@ def export_onnx(
         backbone: torch.nn.Module,
         model_type: ModelType,
         quant_format: QuantFormat,
-        quantize_mha: bool,
     ) -> None:
         """
         Export model to ONNX format.
@@ -831,7 +828,6 @@ def export_onnx(
             return

         self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}")
-        check_conv_and_mha(backbone, quant_format == QuantFormat.FP4, quantize_mha)

         if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone):
             self.logger.info(
@@ -1118,12 +1114,16 @@ def forward_loop(mod):

     export_manager.save_checkpoint(backbone)

+    check_conv_and_mha(
+        backbone, quant_config.format == QuantFormat.FP4, quant_config.quantize_mha
+    )
+    mtq.print_quant_summary(backbone)
+
     export_manager.export_onnx(
         pipe,
         backbone,
         model_config.model_type,
         quant_config.format,
-        quantize_mha=quant_config.quantize_mha,
     )
     logger.info(
         f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"

examples/diffusers/quantization/utils.py

Lines changed: 14 additions & 10 deletions
@@ -25,6 +25,7 @@
 from diffusers.utils import load_image

 import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization.plugins.diffusers import AttentionModuleMixin

 USE_PEFT = True
 try:
@@ -44,21 +45,24 @@ def filter_func_default(name: str) -> bool:


 def check_conv_and_mha(backbone, if_fp4, quantize_mha):
-    for _, module in backbone.named_modules():
+    for name, module in backbone.named_modules():
         if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and if_fp4:
             module.weight_quantizer.disable()
             module.input_quantizer.disable()
-        elif isinstance(module, Attention):
-            if not quantize_mha:
-                continue
+
+            print(f"Disabled NVFP4 Conv layer quantization for layer {name}")
+
+        elif isinstance(module, (Attention, AttentionModuleMixin)):
             head_size = int(module.inner_dim / module.heads)
-            module.q_bmm_quantizer.disable()
-            module.k_bmm_quantizer.disable()
-            module.v_bmm_quantizer.disable()
-            module.softmax_quantizer.disable()
-            module.bmm2_output_quantizer.disable()
-            if head_size % 16 != 0:
+            if not quantize_mha or head_size % 16 != 0:
+                module.q_bmm_quantizer.disable()
+                module.k_bmm_quantizer.disable()
+                module.v_bmm_quantizer.disable()
+                module.softmax_quantizer.disable()
+                module.bmm2_output_quantizer.disable()
                 setattr(module, "_disable_fp8_mha", True)
+
+                print(f"Disabled Attention layer quantization for layer {name}")
             else:
                 setattr(module, "_disable_fp8_mha", False)


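A hedged usage sketch of the updated helper follows. The wrapper function is hypothetical and assumes it is run from the examples/diffusers/quantization directory (so `utils` is importable) on a backbone that has already been quantized with modelopt, meaning its Conv and Attention/AttentionModuleMixin modules carry the quantizer attributes referenced above.

```python
import torch

import modelopt.torch.quantization as mtq
from utils import check_conv_and_mha  # the helper defined above


def finalize_quantization(backbone: torch.nn.Module, is_fp4: bool, quantize_mha: bool) -> None:
    """Hypothetical wrapper showing how the updated helper is meant to be called."""
    # For NVFP4, drop Conv quantization; for attention modules (both diffusers Attention
    # and the new AttentionModuleMixin subclasses), drop the q/k/v/softmax/bmm2 quantizers
    # when MHA quantization is off or the head size is not a multiple of 16.
    check_conv_and_mha(backbone, is_fp4, quantize_mha)
    mtq.print_quant_summary(backbone)
```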
modelopt/torch/quantization/plugins/diffusers.py

Lines changed: 26 additions & 1 deletion
@@ -20,10 +20,21 @@
 from types import ModuleType
 from typing import TYPE_CHECKING

+import diffusers
 import onnx
 import torch
 from diffusers.models.attention_processor import Attention
 from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from packaging.version import parse as parse_version
+
+if parse_version(diffusers.__version__) >= parse_version("0.35.0"):
+    from diffusers.models.attention import AttentionModuleMixin
+    from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend
+    from diffusers.models.transformers.transformer_flux import FluxAttention
+    from diffusers.models.transformers.transformer_ltx import LTXAttention
+    from diffusers.models.transformers.transformer_wan import WanAttention
+else:
+    AttentionModuleMixin = type("_dummy_type_no_instance", (), {})  # pylint: disable=invalid-name
 from torch.autograd import Function
 from torch.nn import functional as F
 from torch.onnx import symbolic_helper
@@ -140,7 +151,7 @@ def _quantized_sdpa(self, *args, **kwargs):


 class _QuantAttention(_QuantFunctionalMixin):
-    """FP8 processor for performing attention-related computations."""
+    """Quantized processor for performing attention-related computations."""

     _functionals_to_replace = [
         (torch, "bmm", _quantized_bmm),
@@ -167,6 +178,20 @@ def _setup(self):
 QuantModuleRegistry.register({Attention: "Attention"})(_QuantAttention)


+if AttentionModuleMixin.__module__.startswith(diffusers.__name__):
+
+    class _QuantAttentionModuleMixin(_QuantAttention):
+        """Quantized AttentionModuleMixin for performing attention-related computations."""
+
+        def forward(self, *args, **kwargs):
+            with attention_backend(AttentionBackendName.NATIVE):
+                return super().forward(*args, **kwargs)
+
+    QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin)
+    QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin)
+    QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin)
+
+
 original_scaled_dot_product_attention = F.scaled_dot_product_attention


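As a rough end-to-end illustration of what these registrations enable, here is a hedged sketch; the checkpoint name, prompts, and FP8 config choice are illustrative and not taken from this PR or the example scripts.

```python
import torch
from diffusers import FluxPipeline

import modelopt.torch.quantization as mtq

# Example checkpoint; with diffusers >= 0.35 its attention layers are FluxAttention
# modules, i.e. AttentionModuleMixin subclasses covered by the registrations above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
backbone = pipe.transformer


def forward_loop(mod):
    # Tiny calibration loop: run a couple of prompts so the quantizers see activations.
    for prompt in ["a photo of a cat", "a watercolor landscape"]:
        pipe(prompt, num_inference_steps=4)


# mtq.quantize converts the registered FluxAttention modules to _QuantAttentionModuleMixin,
# whose forward runs under the NATIVE attention backend so the patched
# F.scaled_dot_product_attention is actually exercised during calibration and inference.
mtq.quantize(backbone, mtq.FP8_DEFAULT_CFG, forward_loop)
mtq.print_quant_summary(backbone)
```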