Commit 747bd7d

apply review suggestions
1 parent f9f1535 commit 747bd7d

3 files changed, +79 -62 lines changed


docs/source/en/quantization/torchao.md

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
 image.save("output.png")
 ```
 
+Additionally, TorchAO supports an automatic quantization API exposed through [`autoquant`](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy for a model by comparing the performance of each technique on chosen input types and shapes. At the moment, this can be used directly with the underlying modeling components, but Diffusers will also expose an autoquant configuration option in the future.
+
 ## Resources
 
 - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
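The autoquant paragraph added above is not yet backed by a `TorchAoConfig` option, so a reader has to call torchao directly. Below is a minimal, hedged sketch of doing that on a pipeline's transformer; the `torchao.autoquant` entry point, the `torch.compile` wrapping, and the FLUX.1-schnell checkpoint follow the torchao README and the existing example in this doc rather than an official Diffusers API.

```python
# Rough sketch only: autoquantizing the underlying modeling component by hand.
# `torchao.autoquant` usage follows the torchao README; exact defaults may differ
# across torchao versions.
import torch
from diffusers import FluxPipeline
from torchao import autoquant

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# autoquant benchmarks candidate quantization techniques on the input shapes it sees
# during the first forward pass and keeps the fastest option per layer.
pipe.transformer = autoquant(torch.compile(pipe.transformer, mode="max-autotune"))

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")
```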

src/diffusers/quantizers/quantization_config.py

Lines changed: 56 additions & 45 deletions
@@ -436,14 +436,14 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non
         else:
             self.quant_type_kwargs = kwargs
 
-        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        if self.quant_type not in _STR_TO_METHOD.keys():
+        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+        if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             raise ValueError(
                 f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
                 f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
             )
 
-        method = _STR_TO_METHOD[self.quant_type]
+        method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
         signature = inspect.signature(method)
         all_kwargs = {
             param.name
@@ -481,24 +481,6 @@ def _get_torchao_quant_type_to_method(cls):
         # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
         from torchao.quantization.observer import PerRow, PerTensor
 
-        INT4_QUANTIZATION_TYPES = {
-            # int4 weight + bfloat16/float16 activation
-            "int4wo": int4_weight_only,
-            "int4_weight_only": int4_weight_only,
-            # int4 weight + int8 activation
-            "int4dq": int8_dynamic_activation_int4_weight,
-            "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
-        }
-
-        INT8_QUANTIZATION_TYPES = {
-            # int8 weight + bfloat16/float16 activation
-            "int8wo": int8_weight_only,
-            "int8_weight_only": int8_weight_only,
-            # int8 weight + int8 activation
-            "int8dq": int8_dynamic_activation_int8_weight,
-            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
-        }
-
         def generate_float8dq_types(dtype: torch.dtype):
             name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
             types = {}
@@ -529,6 +511,41 @@ def generate_fpx_quantization_types(bits: int):
 
             return types
 
+        def generate_uintx_quantization_types(bits: int):
+            UINTX_TO_DTYPE = {
+                1: torch.uint1,
+                2: torch.uint2,
+                3: torch.uint3,
+                4: torch.uint4,
+                5: torch.uint5,
+                6: torch.uint6,
+                7: torch.uint7,
+                # 8: torch.uint8, # uint8 quantization is not supported
+            }
+
+            types = {}
+            types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
+            types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
+            return types
+
+        INT4_QUANTIZATION_TYPES = {
+            # int4 weight + bfloat16/float16 activation
+            "int4wo": int4_weight_only,
+            "int4_weight_only": int4_weight_only,
+            # int4 weight + int8 activation
+            "int4dq": int8_dynamic_activation_int4_weight,
+            "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
+        }
+
+        INT8_QUANTIZATION_TYPES = {
+            # int8 weight + bfloat16/float16 activation
+            "int8wo": int8_weight_only,
+            "int8_weight_only": int8_weight_only,
+            # int8 weight + int8 activation
+            "int8dq": int8_dynamic_activation_int8_weight,
+            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+        }
+
         # TODO(aryan): handle torch 2.2/2.3
         FLOATX_QUANTIZATION_TYPES = {
             # float8_e5m2 weight + bfloat16/float16 activation
@@ -540,7 +557,7 @@ def generate_fpx_quantization_types(bits: int):
             # float8_e5m2 weight + float8 activation (dynamic)
             "float8dq": float8_dynamic_activation_float8_weight,
             "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
-            # ===== Matrix multiplication is not supported in float8_e5m2 so the following error out.
+            # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
             # However, changing activation_dtype=torch.float8_e4m3 might work here =====
             # "float8dq_e5m2": partial(
             #     float8_dynamic_activation_float8_weight,
@@ -566,27 +583,8 @@ def generate_fpx_quantization_types(bits: int):
             **generate_fpx_quantization_types(5),
             **generate_fpx_quantization_types(6),
             **generate_fpx_quantization_types(7),
-            # ===== Errors out with "torch.cat(): expected a non-empty list of Tensors" =====
-            # **generate_fpx_quantization_types(8),
-        }
-
-        UINTX_TO_DTYPE = {
-            1: torch.uint1,
-            2: torch.uint2,
-            3: torch.uint3,
-            4: torch.uint4,
-            5: torch.uint5,
-            6: torch.uint6,
-            7: torch.uint7,
-            # 8: torch.uint8, # uint8 quantization is not supported
         }
 
-        def generate_uintx_quantization_types(bits: int):
-            types = {}
-            types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
-            types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
-            return types
-
         UINTX_QUANTIZATION_DTYPES = {
             "uintx": uintx_weight_only,
             "uintx_weight_only": uintx_weight_only,
@@ -614,6 +612,7 @@ def generate_uintx_quantization_types(bits: int):
             "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7),
             # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
         }
+
         SHORTHAND_FLOAT_QUANTIZATION_TYPES = {
             "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
             "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
@@ -645,8 +644,6 @@ def generate_uintx_quantization_types(bits: int):
     @staticmethod
     def _is_cuda_capability_atleast_8_9() -> bool:
         if not torch.cuda.is_available():
-            if torch.mps.is_available():
-                return False
             raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
 
         major, minor = torch.cuda.get_device_capability()
@@ -655,9 +652,23 @@ def _is_cuda_capability_atleast_8_9() -> bool:
         return major >= 9
 
     def get_apply_tensor_subclass(self):
-        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
+        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+        return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
 
     def __repr__(self):
+        r"""
+        Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
+
+        ```
+        TorchAoConfig {
+            "modules_to_not_convert": null,
+            "quant_method": "torchao",
+            "quant_type": "uint_a16w4",
+            "quant_type_kwargs": {
+                "group_size": 32
+            }
+        }
+        ```
+        """
         config_dict = self.to_dict()
         return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 21 additions & 17 deletions
@@ -41,6 +41,22 @@
 
 logger = logging.get_logger(__name__)
 
+SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+    # At the moment, only int8 is supported for integer quantization dtypes.
+    # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+    # to support more quantization methods, such as intx_weight_only.
+    torch.int8,
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.uint1,
+    torch.uint2,
+    torch.uint3,
+    torch.uint4,
+    torch.uint5,
+    torch.uint6,
+    torch.uint7,
+)
+
 
 def _quantization_type(weight):
     from torchao.dtypes import AffineQuantizedTensor
@@ -123,30 +139,15 @@ def update_torch_dtype(self, torch_dtype):
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        supported_dtypes = (
-            # At the moment, only int8 is supported for integer quantization dtypes.
-            # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
-            # to support more quantization methods, such as intx_weight_only.
-            torch.int8,
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-            torch.uint1,
-            torch.uint2,
-            torch.uint3,
-            torch.uint4,
-            torch.uint5,
-            torch.uint6,
-            torch.uint7,
-        )
-        if isinstance(target_dtype, supported_dtypes):
+        if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
             return target_dtype
 
         # We need one of the supported dtypes to be selected in order for accelerate to determine
         # the total size of modules/parameters for auto device placement. This method will not be
         # called when device_map is not "auto".
         raise ValueError(
             f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype "
-            f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the "
+            f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
             f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
         )
 
@@ -190,10 +191,13 @@ def create_quantized_param(
         module, tensor_name = get_module_from_name(model, param_name)
 
         if self.pre_quantized:
+            # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
+            # about AffineQuantizedTensor
             module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
             if isinstance(module, nn.Linear):
                 module.extra_repr = types.MethodType(_linear_extra_repr, module)
         else:
+            # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
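To see where these quantizer changes are exercised, here is a hedged loading sketch in the style of the existing TorchAO docs page; the checkpoint id is an assumption for illustration. With `device_map="auto"`, `adjust_target_dtype` is the hook that must return an entry of `SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION` so accelerate can size modules for placement; per the code comment above, it is never called otherwise.

```python
# Sketch: int8 weight-only TorchAO quantization applied while loading a Flux transformer.
# create_quantized_param() runs per weight; since the weights are not pre-quantized here,
# quantize_() is applied on the fly and the linear layers' repr already reflects the
# AffineQuantizedTensor info, per the new comments in the diff above.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed checkpoint, for illustration
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
print(transformer)
```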
