Commit 64cbf11

torchao quantizer
1 parent 7ac6e28 commit 64cbf11

File tree

6 files changed: +516 -3 lines changed


src/diffusers/quantizers/auto.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,7 @@
 from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, TorchAoConfig, QuantizationConfigMixin, QuantizationMethod
 
 
 AUTO_QUANTIZER_MAPPING = {
@@ -30,6 +30,7 @@
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
+    "torchao": TorchAoConfig,
 }
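For reference, a minimal sketch (not part of this commit) of how the new `"torchao"` key resolves to its config class through the mapping added above. The `"int8_weight_only"` quant type is an illustrative assumption, and torchao must be installed for `TorchAoConfig` to validate the string:

```python
# Minimal sketch, grounded only in the mapping shown in the diff above.
from diffusers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING

config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["torchao"]  # resolves to TorchAoConfig
quantization_config = config_cls("int8_weight_only")      # quant_type is validated in TorchAoConfig.__init__
```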

src/diffusers/quantizers/quantization_config.py

Lines changed: 243 additions & 2 deletions
@@ -22,15 +22,17 @@
 
 import copy
 import importlib.metadata
+import inspect
 import json
 import os
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Union
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
 
 from packaging import version
 
-from ..utils import is_torch_available, logging
+from ..utils import is_torch_available, is_torchao_available, logging
 
 
 if is_torch_available():
@@ -41,6 +43,7 @@
 
 class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
+    TORCHAO = "torchao"
 
 
 @dataclass
@@ -389,3 +392,241 @@ def to_diff_dict(self) -> Dict[str, Any]:
                 serializable_config_dict[key] = value
 
         return serializable_config_dict
+
+
+@dataclass
+class TorchAoConfig(QuantizationConfigMixin):
+    """This is a config class for torchao quantization/sparsity techniques.
+
+    Args:
+        quant_type (`str`):
+            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`
+            and `int8_dynamic_activation_int8_weight`.
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require some modules
+            to be left in their original precision.
+        kwargs (`Dict[str, Any]`, *optional*):
+            The keyword arguments for the chosen type of quantization. For example, `int4_weight_only` quantization
+            currently supports the keyword arguments `group_size` and `inner_k_tiles`. More API examples and
+            documentation of arguments can be found at
+            https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
+
+    Example:
+
+    ```python
+    TODO(aryan): update
+    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
+    # int4_weight_only quant only works with the *torch.bfloat16* dtype right now
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+    ```
+    """
+
+    def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
+        self.quant_method = QuantizationMethod.TORCHAO
+        self.quant_type = quant_type
+        self.modules_to_not_convert = modules_to_not_convert
+
+        # When we load from serialized config, "quant_type_kwargs" will be the key
+        if "quant_type_kwargs" in kwargs:
+            self.quant_type_kwargs = kwargs["quant_type_kwargs"]
+        else:
+            self.quant_type_kwargs = kwargs
+
+        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
+        if self.quant_type not in _STR_TO_METHOD.keys():
+            raise ValueError(
+                f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
+                f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        method = _STR_TO_METHOD[self.quant_type]
+        signature = inspect.signature(method)
+        all_kwargs = {
+            param.name
+            for param in signature.parameters.values()
+            if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+        }
+        unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+        if len(unsupported_kwargs) > 0:
+            raise ValueError(
+                f"The quantization method \"{method}\" does not support the following keyword arguments: "
+                f"{unsupported_kwargs}. The following keyword arguments are supported: {all_kwargs}."
+            )
+
+    @classmethod
+    def _get_torchao_quant_type_to_method(cls):
+        r"""
+        Returns supported torchao quantization types with all commonly used notations.
+        """
+
+        if is_torchao_available():
+            from torchao.quantization import (
+                int4_weight_only,
+                int8_dynamic_activation_int8_weight,
+                int8_dynamic_activation_int4_weight,
+                int8_weight_only,
+                float8_dynamic_activation_float8_weight,
+                float8_static_activation_float8_weight,
+                float8_weight_only,
+                fpx_weight_only,
+                uintx_weight_only,
+            )
+
+            # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
+            from torchao.quantization.observer import PerRow, PerTensor
+
+            # TODO(aryan): Support autoquant and sparsify
+
+            INT4_QUANTIZATION_TYPES = {
+                # int4 weight + bfloat16/float16 activation
+                "int4": int4_weight_only,
+                "int4wo": int4_weight_only,
+                "int4_weight_only": int4_weight_only,
+                "int4_a16w4": int4_weight_only,
+                # int4 weight + int8 activation
+                "int4dq": int8_dynamic_activation_int4_weight,
+                "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
+                "int4_a8w4": int8_dynamic_activation_int4_weight,
+            }
+
+            INT8_QUANTIZATION_TYPES = {
+                # int8 weight + bfloat16/float16 activation
+                "int8": int8_weight_only,
+                "int8wo": int8_weight_only,
+                "int8_weight_only": int8_weight_only,
+                "int8_a16w8": int8_weight_only,
+                # int8 weight + int8 activation
+                "int8dq": int8_dynamic_activation_int8_weight,
+                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+                "int8_a8w8": int8_dynamic_activation_int8_weight,
+            }
+
+            def generate_float8dq_types(dtype: torch.dtype):
+                name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
+                types = {}
+
+                types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype)
+                for activation_granularity_cls in [PerTensor, PerRow]:
+                    for weight_granularity_cls in [PerTensor, PerRow]:
+                        activation_name = "t" if activation_granularity_cls is PerTensor else "r"
+                        weight_name = "t" if weight_granularity_cls is PerTensor else "r"
+                        # The a{activation_name}w{weight_name} is a made-up name for convenience of testing things.
+                        # It encodes the (activation granularity, weight granularity) pair:
+                        # - atwt: PerTensor(), PerTensor()
+                        # - atwr: PerTensor(), PerRow()
+                        # - arwt: PerRow(), PerTensor()
+                        # - arwr: PerRow(), PerRow()
+                        types[f"float8dq_{name}_a{activation_name}w{weight_name}"] = partial(
+                            float8_dynamic_activation_float8_weight,
+                            activation_dtype=dtype,
+                            weight_dtype=dtype,
+                            granularity=(activation_granularity_cls(), weight_granularity_cls()),
+                        )
+                        types[f"float8dq_{name}_a{activation_name}w{weight_name}_a8w8"] = partial(
+                            float8_dynamic_activation_float8_weight,
+                            activation_dtype=dtype,
+                            weight_dtype=dtype,
+                            granularity=(activation_granularity_cls(), weight_granularity_cls()),
+                        )
+
+                return types
+
+            def generate_fpx_quantization_types(bits: int):
+                types = {}
+
+                for ebits in range(1, bits):
+                    mbits = bits - ebits - 1
+                    types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+                    types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+
+                non_sign_bits = bits - 1
+                default_ebits = (non_sign_bits + 1) // 2
+                default_mbits = non_sign_bits - default_ebits
+                types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+
+                return types
+
+            # TODO(aryan): handle cuda capability and torch 2.2/2.3
+            FLOATX_QUANTIZATION_TYPES = {
+                # float8_e5m2 weight + bfloat16/float16 activation
+                "float8": float8_weight_only,
+                "float8_weight_only": float8_weight_only,
+                "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                "float8_a16w8": float8_weight_only,
+                "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                "float8_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                # float8_e4m3 weight + bfloat16/float16 activation
+                "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+                "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+                "float8wo_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+                # float8_e5m2 weight + float8 activation (dynamic)
+                "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
+                "float8dq": float8_dynamic_activation_float8_weight,
+                "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2),
+                "float8_a8w8": float8_dynamic_activation_float8_weight,
+                **generate_float8dq_types(torch.float8_e5m2),
+                # float8_e4m3 weight + float8 activation (dynamic)
+                "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn),
+                **generate_float8dq_types(torch.float8_e4m3fn),
+                # float8 weight + float8 activation (static)
+                "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
+                "float8sq": float8_static_activation_float8_weight,
+                # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
+                # fpx weight + bfloat16/float16 activation
+                **generate_fpx_quantization_types(3),
+                **generate_fpx_quantization_types(4),
+                **generate_fpx_quantization_types(5),
+                **generate_fpx_quantization_types(6),
+                **generate_fpx_quantization_types(7),
+                **generate_fpx_quantization_types(8),
+            }
+
+            UINTX_TO_DTYPE = {
+                1: torch.uint1,
+                2: torch.uint2,
+                3: torch.uint3,
+                4: torch.uint4,
+                5: torch.uint5,
+                6: torch.uint6,
+                7: torch.uint7,
+                8: torch.uint8,
+            }
+
+            def generate_uintx_quantization_types(bits: int):
+                types = {}
+                types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
+                types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
+                types[f"uint{bits}_a16w{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
+                return types
+
+            UINTX_QUANTIZATION_DTYPES = {
+                "uintx": uintx_weight_only,
+                "uintx_weight_only": uintx_weight_only,
+                **generate_uintx_quantization_types(1),
+                **generate_uintx_quantization_types(2),
+                **generate_uintx_quantization_types(3),
+                **generate_uintx_quantization_types(4),
+                **generate_uintx_quantization_types(5),
+                **generate_uintx_quantization_types(6),
+                **generate_uintx_quantization_types(7),
+                **generate_uintx_quantization_types(8),
+            }
+
+            QUANTIZATION_TYPES = {}
+            QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
+
+            return QUANTIZATION_TYPES
+        else:
+            raise ValueError(
+                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+            )
+
+    def get_apply_tensor_subclass(self):
+        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
+        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
+
+    def __repr__(self):
+        config_dict = self.to_dict()
+        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
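The quant-type strings registered above are aliases for torchao transforms (for example `"int8wo"`, `"int8_weight_only"`, and `"int8_a16w8"` all map to `int8_weight_only`), and `get_apply_tensor_subclass()` materializes the chosen transform with its keyword arguments. A minimal sketch of applying that transform with torchao's `quantize_` API, assuming torchao is installed; this is an illustration, not the commit's `TorchAoHfQuantizer`:

```python
# Sketch only: feed the transform produced by TorchAoConfig to torchao's quantize_.
import torch.nn as nn
from torchao.quantization import quantize_

from diffusers.quantizers.quantization_config import TorchAoConfig

config = TorchAoConfig("int8wo")  # alias for int8_weight_only
model = nn.Sequential(nn.Linear(64, 64))
quantize_(model, config.get_apply_tensor_subclass())  # swaps Linear weights for quantized tensor subclasses
```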
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .torchao_quantizer import TorchAoHfQuantizer
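Taken together, the new config class, the auto-mapping entry, and the `TorchAoHfQuantizer` export are meant to let a model be loaded with torchao quantization applied on the fly. A hedged end-to-end sketch of that flow; the FLUX.1-dev checkpoint and the assumption that `from_pretrained` routes `quantization_config` through the quantizer machinery (mirroring the existing bitsandbytes path) are not part of this diff:

```python
# Usage sketch under the stated assumptions; not taken from this commit.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint, for illustration only
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```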
