
Commit 748a002

update
1 parent ee084a5 commit 748a002

File tree

7 files changed (+67, -27 lines)

docs/source/en/api/quantization.md

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

[[autodoc]] BitsAndBytesConfig

+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer

docs/source/en/quantization/overview.md

Lines changed: 1 addition & 1 deletion
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be

## When to use what?

-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes` and `torchao`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# torchao
+
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
+
+Before you begin, make sure you have PyTorch 2.5 or above and TorchAO installed:
+
+```bash
+pip install -U torch torchao
+```
+
+## Usage
+
+Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+## Resources
+
+- [TorchAO Quantization API]()
+- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
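The usage described in the new doc page above might look like the following in practice. This is a minimal sketch rather than part of the commit: the Flux checkpoint and the `"int8wo"` type string are illustrative choices, and any model with `torch.nn.Linear` layers plus any type key registered in `quantization_config.py` should work the same way.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# "int8wo" (int8 weight-only) is one of the type strings registered in
# TorchAoConfig's quantization table; see quantization_config.py below.
quantization_config = TorchAoConfig("int8wo")

# Quantization is applied while the checkpoint is loaded. The model id and
# subfolder here are illustrative assumptions for this sketch.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```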

src/diffusers/models/model_loading_utils.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@
import torch
from huggingface_hub.utils import EntryNotFoundError

-from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,

src/diffusers/models/modeling_utils.py

Lines changed: 0 additions & 3 deletions
@@ -829,9 +829,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        if device_map is None and not is_sharded:
            # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
            # It would error out during the `validate_environment()` call above in the absence of cuda.
-            is_quant_method_bnb = (
-                getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-            )
            if hf_quantizer is None:
                param_device = "cpu"
                # TODO (sayakpaul, SunMarc): remove this after model loading refactor

src/diffusers/quantizers/auto.py

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .torchao import TorchAoHfQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
+from .torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {

src/diffusers/quantizers/quantization_config.py

Lines changed: 30 additions & 21 deletions
@@ -484,32 +484,25 @@ def _get_torchao_quant_type_to_method(cls):
            "int4": int4_weight_only,
            "int4wo": int4_weight_only,
            "int4_weight_only": int4_weight_only,
-            "int4_a16w4": int4_weight_only,
            # int4 weight + int8 activation
            "int4dq": int8_dynamic_activation_int4_weight,
            "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
-            "int4_a8w4": int8_dynamic_activation_int4_weight,
        }

        INT8_QUANTIZATION_TYPES = {
            # int8 weight + bfloat16/float16 activation
            "int8": int8_weight_only,
            "int8wo": int8_weight_only,
            "int8_weight_only": int8_weight_only,
-            "int8_a16w8": int8_weight_only,
            # int8 weight + int8 activation
            "int8dq": int8_dynamic_activation_int8_weight,
            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
-            "int8_a8w8": int8_dynamic_activation_int8_weight,
        }

        def generate_float8dq_types(dtype: torch.dtype):
            name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
            types = {}

-            types[f"float8dq_{name}_a8w8"] = partial(
-                float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype
-            )
            for activation_granularity_cls in [PerTensor, PerRow]:
                for weight_granularity_cls in [PerTensor, PerRow]:
                    activation_name = "t" if activation_granularity_cls is PerTensor else "r"

@@ -526,22 +519,15 @@ def generate_fpx_quantization_types(bits: int):
                        weight_dtype=dtype,
                        granularity=(activation_granularity_cls(), weight_granularity_cls()),
                    )
-                    types[f"float8dq_{name}_a{activation_name}w{weight_name}_a8w8"] = partial(
-                        float8_dynamic_activation_float8_weight,
-                        activation_dtype=dtype,
-                        weight_dtype=dtype,
-                        granularity=(activation_granularity_cls(), weight_granularity_cls()),
-                    )

            return types

        def generate_fpx_quantization_types(bits: int):
            types = {}

-            for ebits in range(1, bits):
+            for ebits in range(0, bits):
                mbits = bits - ebits - 1
                types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
-                types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)

            non_sign_bits = bits - 1
            default_ebits = (non_sign_bits + 1) // 2

@@ -550,20 +536,17 @@ def generate_fpx_quantization_types(bits: int):

            return types

-        # TODO(aryan): handle cuda capability and torch 2.2/2.3
+        # TODO(aryan): handle torch 2.2/2.3
        FLOATX_QUANTIZATION_TYPES = {
            # float8_e5m2 weight + bfloat16/float16 activation
            "float8": float8_weight_only,
            "float8_weight_only": float8_weight_only,
            "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
-            "float8_a16w8": float8_weight_only,
            "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
            "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
-            "float8_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
            # float8_e4m3 weight + bfloat16/float16 activation
            "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
            "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
-            "float8wo_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
            # float8_e5m2 weight + float8 activation (dynamic)
            "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
            "float8dq": float8_dynamic_activation_float8_weight,

@@ -572,7 +555,6 @@ def generate_fpx_quantization_types(bits: int):
                activation_dtype=torch.float8_e5m2,
                weight_dtype=torch.float8_e5m2,
            ),
-            "float8_a8w8": float8_dynamic_activation_float8_weight,
            **generate_float8dq_types(torch.float8_e5m2),
            # float8_e4m3 weight + float8 activation (dynamic)
            "float8dq_e4m3": partial(

@@ -609,7 +591,6 @@ def generate_uintx_quantization_types(bits: int):
            types = {}
            types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
            types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
-            types[f"uint{bits}_a16w{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits])
            return types

        UINTX_QUANTIZATION_DTYPES = {

@@ -625,13 +606,41 @@ def generate_uintx_quantization_types(bits: int):
            **generate_uintx_quantization_types(8),
        }

+        SHORTHAND_QUANTIZATION_TYPES = {
+            "int_a16w4": int4_weight_only,
+            "int_a8w4": int8_dynamic_activation_int4_weight,
+            "int_a16w8": int8_weight_only,
+            "int_a8w8": int8_dynamic_activation_int8_weight,
+            "uint_a16w1": partial(uintx_weight_only, dtype=torch.uint1),
+            "uint_a16w2": partial(uintx_weight_only, dtype=torch.uint2),
+            "uint_a16w3": partial(uintx_weight_only, dtype=torch.uint3),
+            "uint_a16w4": partial(uintx_weight_only, dtype=torch.uint4),
+            "uint_a16w5": partial(uintx_weight_only, dtype=torch.uint5),
+            "uint_a16w6": partial(uintx_weight_only, dtype=torch.uint6),
+            "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7),
+            "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8),
+        }
+        SHORTHAND_FLOAT_QUANTIZATION_TYPES = {
+            "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+            "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+            "float_a8w8": float8_dynamic_activation_float8_weight,
+            "float_a16w3": partial(fpx_weight_only, ebits=2, mbits=0),
+            "float_a16w4": partial(fpx_weight_only, ebits=2, mbits=1),
+            "float_a16w5": partial(fpx_weight_only, ebits=3, mbits=1),
+            "float_a16w6": partial(fpx_weight_only, ebits=3, mbits=2),
+            "float_a16w7": partial(fpx_weight_only, ebits=4, mbits=2),
+            "float_a16w8": partial(fpx_weight_only, ebits=5, mbits=2),
+        }
+
        QUANTIZATION_TYPES = {}
        QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
        QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
        QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
+        QUANTIZATION_TYPES.update(SHORTHAND_QUANTIZATION_TYPES)

        if cls._is_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(SHORTHAND_FLOAT_QUANTIZATION_TYPES)

            return QUANTIZATION_TYPES
        else:
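The new `SHORTHAND_QUANTIZATION_TYPES` and `SHORTHAND_FLOAT_QUANTIZATION_TYPES` tables encode a scheme as `a<activation bits>w<weight bits>`. Assuming these keys are passed to [`TorchAoConfig`] the same way as the longhand names (they are merged into the same `QUANTIZATION_TYPES` mapping above), selecting one is a one-liner; a sketch:

```python
from diffusers import TorchAoConfig

# Shorthand keys from the table above; aNN = activation bits, wNN = weight bits.
#   "int_a16w4"   -> int4_weight_only (16-bit activations, int4 weights)
#   "int_a8w8"    -> int8_dynamic_activation_int8_weight
#   "float_a16w6" -> fpx_weight_only(ebits=3, mbits=2); the float_* keys are only
#                    registered on GPUs with CUDA capability >= 8.9
quantization_config = TorchAoConfig("int_a8w8")
```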

0 commit comments
