From 64cbf116bc49dbdcd39587df5ae5289934dde850 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 02:38:19 +0100 Subject: [PATCH 01/27] torchao quantizer --- src/diffusers/quantizers/auto.py | 3 +- .../quantizers/quantization_config.py | 245 +++++++++++++++++- src/diffusers/quantizers/torchao/__init__.py | 15 ++ .../quantizers/torchao/torchao_quantizer.py | 236 +++++++++++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 19 ++ 6 files changed, 516 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/quantizers/torchao/__init__.py create mode 100644 src/diffusers/quantizers/torchao/torchao_quantizer.py diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 97cbcdc0e53f..2857c23d16da 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod +from .quantization_config import BitsAndBytesConfig, TorchAoConfig, QuantizationConfigMixin, QuantizationMethod AUTO_QUANTIZER_MAPPING = { @@ -30,6 +30,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = { "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, + "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index f521c5d717d6..93bd04ac649c 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -22,15 +22,17 @@ import copy import importlib.metadata +import inspect import json import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Union +from functools import partial +from typing import Any, Dict, List, Optional, Union from packaging import version -from ..utils import is_torch_available, logging +from ..utils import is_torch_available, is_torchao_available, logging if is_torch_available(): @@ -41,6 +43,7 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" + TORCHAO = "torchao" @dataclass @@ -389,3 +392,241 @@ def to_diff_dict(self) -> Dict[str, Any]: serializable_config_dict[key] = value return serializable_config_dict + + +@dataclass +class TorchAoConfig(QuantizationConfigMixin): + """This is a config class for torchao quantization/sparsity techniques. + + Args: + quant_type (`str`): + The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + kwargs (`Dict[str, Any]`, *optional*): + The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments + `group_size` and `inner_k_tiles` currently. 
More API examples and documentation of arguments can be found in + https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + + Example: + + ```python + TODO(aryan): update + quantization_config = TorchAoConfig("int4_weight_only", group_size=32) + # int4_weight_only quant is only working with *torch.bfloat16* dtype right now + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + ``` + """ + + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): + self.quant_method = QuantizationMethod.TORCHAO + self.quant_type = quant_type + self.modules_to_not_convert = modules_to_not_convert + + # When we load from serialized config, "quant_type_kwargs" will be the key + if "quant_type_kwargs" in kwargs: + self.quant_type_kwargs = kwargs["quant_type_kwargs"] + else: + self.quant_type_kwargs = kwargs + + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + if self.quant_type not in _STR_TO_METHOD.keys(): + raise ValueError( + f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the " + f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + method = _STR_TO_METHOD[self.quant_type] + signature = inspect.signature(method) + all_kwargs = { + param.name + for param in signature.parameters.values() + if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] + } + unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) + + if len(unsupported_kwargs) > 0: + raise ValueError( + f"The quantization method \"{method}\" does not supported the following keyword arguments: " + f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." + ) + + @classmethod + def _get_torchao_quant_type_to_method(cls): + r""" + Returns supported torchao quantization types with all commonly used notations. 
+ """ + + if is_torchao_available(): + from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_dynamic_activation_int4_weight, + int8_weight_only, + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + fpx_weight_only, + uintx_weight_only, + ) + # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers + from torchao.quantization.observer import PerRow, PerTensor + + # TODO(aryan): Support autoquant and sparsify + + INT4_QUANTIZATION_TYPES = { + # int4 weight + bfloat16/float16 activation + "int4": int4_weight_only, + "int4wo": int4_weight_only, + "int4_weight_only": int4_weight_only, + "int4_a16w4": int4_weight_only, + # int4 weight + int8 activation + "int4dq": int8_dynamic_activation_int4_weight, + "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, + "int4_a8w4": int8_dynamic_activation_int4_weight, + } + + INT8_QUANTIZATION_TYPES = { + # int8 weight + bfloat16/float16 activation + "int8": int8_weight_only, + "int8wo": int8_weight_only, + "int8_weight_only": int8_weight_only, + "int8_a16w8": int8_weight_only, + # int8 weight + int8 activation + "int8dq": int8_dynamic_activation_int8_weight, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + "int8_a8w8": int8_dynamic_activation_int8_weight, + } + + def generate_float8dq_types(dtype: torch.dtype): + name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" + types = {} + + types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype) + for activation_granularity_cls in [PerTensor, PerRow]: + for weight_granularity_cls in [PerTensor, PerRow]: + activation_name = "t" if activation_granularity_cls is PerTensor else "r" + weight_name = "t" if weight_granularity_cls is PerTensor else "r" + # The a{activation_name}w{weight_name} is a made up name for convenience of testing things. 
+ # It suffixes with for different granularities (activation granularity, weight granularity): + # - atwt: PerTensor(), PerTensor() + # - atwr: PerTensor(), PerRow() + # - arwt: PerRow(), PerTensor() + # - arwr: PerRow(), PerRow() + types[f"float8dq_{name}_a{activation_name}w{weight_name}"] = partial( + float8_dynamic_activation_float8_weight, + activation_dtype=dtype, + weight_dtype=dtype, + granularity=(activation_granularity_cls(), weight_granularity_cls()), + ) + types[f"float8dq_{name}_a{activation_name}w{weight_name}_a8w8"] = partial( + float8_dynamic_activation_float8_weight, + activation_dtype=dtype, + weight_dtype=dtype, + granularity=(activation_granularity_cls(), weight_granularity_cls()), + ) + + return types + + def generate_fpx_quantization_types(bits: int): + types = {} + + for ebits in range(1, bits): + mbits = bits - ebits - 1 + types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) + types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) + + non_sign_bits = bits - 1 + default_ebits = (non_sign_bits + 1) // 2 + default_mbits = non_sign_bits - default_ebits + types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) + + return types + + # TODO(aryan): handle cuda capability and torch 2.2/2.3 + FLOATX_QUANTIZATION_TYPES = { + # float8_e5m2 weight + bfloat16/float16 activation + "float8": float8_weight_only, + "float8_weight_only": float8_weight_only, + "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float8_a16w8": float8_weight_only, + "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float8_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + # float8_e4m3 weight + bfloat16/float16 activation + "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), + "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), + "float8wo_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), + # float8_e5m2 weight + float8 activation (dynamic) + "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, + "float8dq": float8_dynamic_activation_float8_weight, + "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2), + "float8_a8w8": float8_dynamic_activation_float8_weight, + **generate_float8dq_types(torch.float8_e5m2), + # float8_e4m3 weight + float8 activation (dynamic) + "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn), + **generate_float8dq_types(torch.float8_e4m3fn), + # float8 weight + float8 activation (static) + "float8_static_activation_float8_weight": float8_static_activation_float8_weight, + "float8sq": float8_static_activation_float8_weight, + # For fpx, only x <= 8 is supported by default. 
Other dtypes can be explored by users directly + # fpx weight + bfloat16/float16 activation + **generate_fpx_quantization_types(3), + **generate_fpx_quantization_types(4), + **generate_fpx_quantization_types(5), + **generate_fpx_quantization_types(6), + **generate_fpx_quantization_types(7), + **generate_fpx_quantization_types(8), + } + + UINTX_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + 8: torch.uint8, + } + + def generate_uintx_quantization_types(bits: int): + types = {} + types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) + types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) + types[f"uint{bits}_a16w{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) + return types + + UINTX_QUANTIZATION_DTYPES = { + "uintx": uintx_weight_only, + "uintx_weight_only": uintx_weight_only, + **generate_uintx_quantization_types(1), + **generate_uintx_quantization_types(2), + **generate_uintx_quantization_types(3), + **generate_uintx_quantization_types(4), + **generate_uintx_quantization_types(5), + **generate_uintx_quantization_types(6), + **generate_uintx_quantization_types(7), + **generate_uintx_quantization_types(8), + } + + QUANTIZATION_TYPES = {} + QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) + + return QUANTIZATION_TYPES + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + + def get_apply_tensor_subclass(self): + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" diff --git a/src/diffusers/quantizers/torchao/__init__.py b/src/diffusers/quantizers/torchao/__init__.py new file mode 100644 index 000000000000..09e6a19d4df0 --- /dev/null +++ b/src/diffusers/quantizers/torchao/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .torchao_quantizer import TorchAoHfQuantizer diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py new file mode 100644 index 000000000000..275ffe0119aa --- /dev/null +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -0,0 +1,236 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py +""" + +import importlib +import types +from packaging import version +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from ..base import DiffusersQuantizer +from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + + +logger = logging.get_logger(__name__) + + +def _quantization_type(weight): + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + +def _linear_extra_repr(self): + weight = _quantization_type(self.weight) + if weight is None: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" + else: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}" + + +class TorchAoHfQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/. + """ + + requires_calibration = False + required_packages = ["torchao"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_torchao_available(): + raise ImportError("Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`") + + self.offload = False + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized torchao model " + "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." + ) + else: + self.offload = True + + if self.pre_quantized: + weights_only = kwargs.get("weights_only", None) + if weights_only: + torch_version = version.parse(importlib.metadata.version("torch")) + if torch_version < version.parse("2.5.0"): + # TODO(aryan): TorchAO is compatible with Pytorch 2.2 for certain quantization types. Try to see if we can support it + raise RuntimeError( + f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}." 
+ ) + + def update_torch_dtype(self, torch_dtype): + quant_type = self.quantization_config.quant_type + + if quant_type.startswith("int") or quant_type.startswith("uint"): + if torch_dtype is not None and torch_dtype != torch.bfloat16: + logger.warning( + f"Setting torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." + ) + + if torch_dtype is None: + # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op + logger.info( + f"Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " + f"to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " + f"dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." + ) + torch_dtype = torch.bfloat16 + + return torch_dtype + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + supported_dtypes = (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) + if isinstance(target_dtype, supported_dtypes): + return target_dtype + + raise ValueError( + f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype " + f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the " + f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # TODO(aryan): Not sure what to do here @sayakpaul, @DN6 + # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128 + max_memory = {key: val * 0.9 for key, val in max_memory.items()} + return max_memory + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + param_device = kwargs.pop("param_device", None) + # check if the param_name is not in self.modules_to_not_convert + if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): + return False + elif param_device == "cpu" and self.offload: + # We don't quantize weights that we offload + return False + else: + # we only quantize the weight of nn.Linear + module, tensor_name = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: List[str], + ): + r""" + Each nn.Linear layer that needs to be quantized is processsed here. + First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. 
+ """ + from torchao.quantization import quantize_ + + module, tensor_name = get_module_from_name(model, param_name) + + if self.pre_quantized: + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + else: + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin"): + return model + + def is_serializable(self, safe_serialization=None): + # TODO(aryan): needs to be tested + if safe_serialization: + logger.warning( + "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + ) + return False + + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( + "0.25.0" + ) + + if not _is_torchao_serializable: + logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") + + if self.offload and self.quantization_config.modules_to_not_convert is None: + logger.warning( + "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them." + "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config." 
+ ) + return False + + return _is_torchao_serializable + + @property + def is_trainable(self): + # TODO(aryan): needs testing + return self.quantization_config.quant_type.startswith("int8") diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c8f64adf3e8a..1e22cddfc120 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -82,6 +82,7 @@ is_sentencepiece_available, is_tensorboard_available, is_timm_available, + is_torchao_available, is_torch_available, is_torch_npu_available, is_torch_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f1323bf00ea4..274bda74a391 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -340,6 +340,15 @@ def is_timm_available(): _imageio_available = False +_is_torchao_available = importlib.util.find_spec("torchao") is not None +if _is_torchao_available: + try: + _torchao_version = importlib_metadata.version("torchao") + logger.debug(f"Successfully import gguf version {_torchao_version}") + except importlib_metadata.PackageNotFoundError: + _is_torchao_available = False + + def is_torch_available(): return _torch_available @@ -460,6 +469,10 @@ def is_imageio_available(): return _imageio_available +def is_torchao_available(): + return _is_torchao_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -593,6 +606,11 @@ def is_imageio_available(): {0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg` """ +# docstyle-ignore +TORCHAO_IMPORT_ERROR = """ +{0} requires the torchao library but it was not found in your environment. 
You can install it with pip: `pip install torchao` +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -618,6 +636,7 @@ def is_imageio_available(): ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ] ) From b78a36cb61e853aa2527c077237e24a03a3b8290 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 02:38:33 +0100 Subject: [PATCH 02/27] make style --- src/diffusers/quantizers/auto.py | 2 +- .../quantizers/quantization_config.py | 55 ++++++++++------- .../quantizers/torchao/torchao_quantizer.py | 60 ++++++++++++------- src/diffusers/utils/__init__.py | 2 +- 4 files changed, 75 insertions(+), 44 deletions(-) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 2857c23d16da..8c1f3acaede5 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .quantization_config import BitsAndBytesConfig, TorchAoConfig, QuantizationConfigMixin, QuantizationMethod +from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig AUTO_QUANTIZER_MAPPING = { diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 93bd04ac649c..c2e5ab2e513c 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -400,13 +400,15 @@ class TorchAoConfig(QuantizationConfigMixin): Args: quant_type (`str`): - The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and + `int8_dynamic_activation_int8_weight`. modules_to_not_convert (`list`, *optional*, default to `None`): - The list of modules to not quantize, useful for quantizing models that explicitly require to have - some modules left in their original precision. + The list of modules to not quantize, useful for quantizing models that explicitly require to have some + modules left in their original precision. kwargs (`Dict[str, Any]`, *optional*): - The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments - `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in + The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization + supports two keyword arguments `group_size` and `inner_k_tiles` currently. 
More API examples and + documentation of arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques Example: @@ -415,7 +417,9 @@ class TorchAoConfig(QuantizationConfigMixin): TODO(aryan): update quantization_config = TorchAoConfig("int4_weight_only", group_size=32) # int4_weight_only quant is only working with *torch.bfloat16* dtype right now - model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + model = AutoModelForCausalLM.from_pretrained( + model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config + ) ``` """ @@ -423,7 +427,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert - + # When we load from serialized config, "quant_type_kwargs" will be the key if "quant_type_kwargs" in kwargs: self.quant_type_kwargs = kwargs["quant_type_kwargs"] @@ -448,7 +452,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non if len(unsupported_kwargs) > 0: raise ValueError( - f"The quantization method \"{method}\" does not supported the following keyword arguments: " + f'The quantization method "{method}" does not supported the following keyword arguments: ' f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." ) @@ -460,16 +464,17 @@ def _get_torchao_quant_type_to_method(cls): if is_torchao_available(): from torchao.quantization import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_dynamic_activation_int4_weight, - int8_weight_only, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, fpx_weight_only, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, uintx_weight_only, ) + # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers from torchao.quantization.observer import PerRow, PerTensor @@ -502,8 +507,10 @@ def _get_torchao_quant_type_to_method(cls): def generate_float8dq_types(dtype: torch.dtype): name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" types = {} - - types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype) + + types[f"float8dq_{name}_a8w8"] = partial( + float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype + ) for activation_granularity_cls in [PerTensor, PerRow]: for weight_granularity_cls in [PerTensor, PerRow]: activation_name = "t" if activation_granularity_cls is PerTensor else "r" @@ -526,22 +533,22 @@ def generate_float8dq_types(dtype: torch.dtype): weight_dtype=dtype, granularity=(activation_granularity_cls(), weight_granularity_cls()), ) - + return types def generate_fpx_quantization_types(bits: int): types = {} - + for ebits in range(1, bits): mbits = bits - ebits - 1 types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) - + non_sign_bits = bits - 1 default_ebits = (non_sign_bits + 1) // 2 default_mbits = non_sign_bits - default_ebits types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) - + return types # TODO(aryan): handle cuda capability and 
torch 2.2/2.3 @@ -561,11 +568,19 @@ def generate_fpx_quantization_types(bits: int): # float8_e5m2 weight + float8 activation (dynamic) "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, "float8dq": float8_dynamic_activation_float8_weight, - "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2), + "float8dq_e5m2": partial( + float8_dynamic_activation_float8_weight, + activation_dtype=torch.float8_e5m2, + weight_dtype=torch.float8_e5m2, + ), "float8_a8w8": float8_dynamic_activation_float8_weight, **generate_float8dq_types(torch.float8_e5m2), # float8_e4m3 weight + float8 activation (dynamic) - "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn), + "float8dq_e4m3": partial( + float8_dynamic_activation_float8_weight, + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + ), **generate_float8dq_types(torch.float8_e4m3fn), # float8 weight + float8 activation (static) "float8_static_activation_float8_weight": float8_static_activation_float8_weight, diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 275ffe0119aa..46e89c166169 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -19,11 +19,13 @@ import importlib import types -from packaging import version from typing import TYPE_CHECKING, Any, Dict, List, Union -from ..base import DiffusersQuantizer +from packaging import version + from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging +from ..base import DiffusersQuantizer + if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin @@ -69,10 +71,12 @@ def __init__(self, quantization_config, **kwargs): def validate_environment(self, *args, **kwargs): if not is_torchao_available(): - raise ImportError("Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`") + raise ImportError( + "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`" + ) self.offload = False - + device_map = kwargs.get("device_map", None) if isinstance(device_map, dict): if "cpu" in device_map.values() or "disk" in device_map.values(): @@ -83,7 +87,7 @@ def validate_environment(self, *args, **kwargs): ) else: self.offload = True - + if self.pre_quantized: weights_only = kwargs.get("weights_only", None) if weights_only: @@ -96,29 +100,41 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - + if quant_type.startswith("int") or quant_type.startswith("uint"): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( f"Setting torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." ) - + if torch_dtype is None: # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op logger.info( - f"Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " - f"to enable model loading in different precisions. 
Pass your own `torch_dtype` to specify the " - f"dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." + "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " + "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " + "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." ) torch_dtype = torch.bfloat16 - + return torch_dtype def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - supported_dtypes = (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8) + supported_dtypes = ( + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + torch.uint8, + ) if isinstance(target_dtype, supported_dtypes): return target_dtype - + raise ValueError( f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype " f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the " @@ -161,8 +177,8 @@ def create_quantized_param( unexpected_keys: List[str], ): r""" - Each nn.Linear layer that needs to be quantized is processsed here. - First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. + Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, + then we move it to the target device. Finally, we quantize the module. """ from torchao.quantization import quantize_ @@ -187,14 +203,14 @@ def _process_model_before_weight_loading( if not isinstance(self.modules_to_not_convert, list): self.modules_to_not_convert = [self.modules_to_not_convert] - + self.modules_to_not_convert.extend(keep_in_fp32_modules) # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` if isinstance(device_map, dict) and len(device_map.keys()) > 1: keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] self.modules_to_not_convert.extend(keys_on_cpu) - + # Purge `None`. # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 # in case of diffusion transformer models. For language models and others alike, `lm_head` @@ -202,7 +218,7 @@ def _process_model_before_weight_loading( self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] model.config.quantization_config = self.quantization_config - + def _process_model_after_weight_loading(self, model: "ModelMixin"): return model @@ -213,21 +229,21 @@ def is_serializable(self, safe_serialization=None): "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." ) return False - + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( "0.25.0" ) - + if not _is_torchao_serializable: logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") - + if self.offload and self.quantization_config.modules_to_not_convert is None: logger.warning( "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them." 
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config." ) return False - + return _is_torchao_serializable @property diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 1e22cddfc120..3b04fdf27014 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -82,11 +82,11 @@ is_sentencepiece_available, is_tensorboard_available, is_timm_available, - is_torchao_available, is_torch_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, + is_torchao_available, is_torchsde_available, is_torchvision_available, is_transformers_available, From 355509ed91d74d46d8993fc734fd3246990833ee Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 02:42:46 +0100 Subject: [PATCH 03/27] update --- src/diffusers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4749af5f61b..eb3dfd2ff53b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -551,7 +551,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig + from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig try: if not is_onnx_available(): From cbb0da49950206f6a8784be9a3e84ddb38d838bb Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 06:21:34 +0100 Subject: [PATCH 04/27] update --- src/diffusers/models/model_loading_utils.py | 8 +++----- src/diffusers/models/modeling_utils.py | 2 +- src/diffusers/quantizers/auto.py | 2 ++ src/diffusers/quantizers/torchao/torchao_quantizer.py | 7 ++++--- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 932a94571107..a7453b45a4a8 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -176,11 +176,9 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: - if hf_quantizer is None: - device = device or torch.device("cpu") + device = device or torch.device("cpu") dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None - is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) empty_state_dict = model.state_dict() @@ -213,12 +211,12 @@ def load_model_dict_into_meta( # bnb params are flattened. if empty_state_dict[param_name].shape != param.shape: if ( - is_quant_method_bnb + is_quantized and hf_quantizer.pre_quantized and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) ): hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) - elif not is_quant_method_bnb: + else: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. 
If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4a486fd4ce40..8701c952b453 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -835,7 +835,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if hf_quantizer is None: param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor - elif is_quant_method_bnb: + else: param_device = torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 8c1f3acaede5..adc8d2393e26 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -19,12 +19,14 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .torchao import TorchAoHfQuantizer from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig AUTO_QUANTIZER_MAPPING = { "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, + "torchao": TorchAoHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 46e89c166169..10522eaa390a 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -101,15 +101,16 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - if quant_type.startswith("int") or quant_type.startswith("uint"): + if quant_type.startswith("int"): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( - f"Setting torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." + f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " + f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." ) if torch_dtype is None: # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op - logger.info( + logger.warning( "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." 
From ee084a5101152d7ac1c2f8971218482d12247cd2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 06:52:05 +0100 Subject: [PATCH 05/27] cuda capability check --- src/diffusers/quantizers/quantization_config.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index c2e5ab2e513c..dc987f0653f4 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -463,6 +463,7 @@ def _get_torchao_quant_type_to_method(cls): """ if is_torchao_available(): + # TODO(aryan): Support autoquant and sparsify from torchao.quantization import ( float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, @@ -478,8 +479,6 @@ def _get_torchao_quant_type_to_method(cls): # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers from torchao.quantization.observer import PerRow, PerTensor - # TODO(aryan): Support autoquant and sparsify - INT4_QUANTIZATION_TYPES = { # int4 weight + bfloat16/float16 activation "int4": int4_weight_only, @@ -629,15 +628,24 @@ def generate_uintx_quantization_types(bits: int): QUANTIZATION_TYPES = {} QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) - QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) + if cls._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) + return QUANTIZATION_TYPES else: raise ValueError( "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" ) + @staticmethod + def _is_cuda_capability_atleast_8_9() -> bool: + major, minor = torch.cuda.get_device_capability() + if major == 8: + return minor >= 9 + return major >= 9 + def get_apply_tensor_subclass(self): _STR_TO_METHOD = self._get_torchao_quant_type_to_method() return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) From 748a0023ec7f8f8e8a692225575414fa1f89bd55 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 24 Nov 2024 10:46:50 +0100 Subject: [PATCH 06/27] update --- docs/source/en/api/quantization.md | 4 ++ docs/source/en/quantization/overview.md | 2 +- docs/source/en/quantization/torchao.md | 31 +++++++++++ src/diffusers/models/model_loading_utils.py | 1 - src/diffusers/models/modeling_utils.py | 3 -- src/diffusers/quantizers/auto.py | 2 +- .../quantizers/quantization_config.py | 51 +++++++++++-------- 7 files changed, 67 insertions(+), 27 deletions(-) create mode 100644 docs/source/en/quantization/torchao.md diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2fbde9e707ea..18aadf3111bd 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui [[autodoc]] BitsAndBytesConfig +## TorchAoConfig + +[[autodoc]] TorchAoConfig + ## DiffusersQuantizer [[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index d8adbc85a259..99d381e3a537 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? -This section will be expanded once Diffusers has multiple quantization backends. 
Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file +This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes` and `torchao`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md new file mode 100644 index 000000000000..d10fb4b33cf2 --- /dev/null +++ b/docs/source/en/quantization/torchao.md @@ -0,0 +1,31 @@ + + +# torchao + +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). + +Before you begin, make sure you have Pytorch version 2.5, or above, and TorchAO installed: + +```bash +pip install -U torch torchao +``` + +## Usage + +Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + +## Usage + +## Resources + +- [TorchAO Quantization API]() +- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a7453b45a4a8..251d61fa56d9 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -25,7 +25,6 @@ import torch from huggingface_hub.utils import EntryNotFoundError -from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8701c952b453..5c475c748e4a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -829,9 +829,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None and not is_sharded: # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. # It would error out during the `validate_environment()` call above in the absence of cuda. 
- is_quant_method_bnb = ( - getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) if hf_quantizer is None: param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index adc8d2393e26..098308ae0bdc 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -19,8 +19,8 @@ from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .torchao import TorchAoHfQuantizer from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig +from .torchao import TorchAoHfQuantizer AUTO_QUANTIZER_MAPPING = { diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index dc987f0653f4..07f747819c70 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -484,11 +484,9 @@ def _get_torchao_quant_type_to_method(cls): "int4": int4_weight_only, "int4wo": int4_weight_only, "int4_weight_only": int4_weight_only, - "int4_a16w4": int4_weight_only, # int4 weight + int8 activation "int4dq": int8_dynamic_activation_int4_weight, "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, - "int4_a8w4": int8_dynamic_activation_int4_weight, } INT8_QUANTIZATION_TYPES = { @@ -496,20 +494,15 @@ def _get_torchao_quant_type_to_method(cls): "int8": int8_weight_only, "int8wo": int8_weight_only, "int8_weight_only": int8_weight_only, - "int8_a16w8": int8_weight_only, # int8 weight + int8 activation "int8dq": int8_dynamic_activation_int8_weight, "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, - "int8_a8w8": int8_dynamic_activation_int8_weight, } def generate_float8dq_types(dtype: torch.dtype): name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" types = {} - types[f"float8dq_{name}_a8w8"] = partial( - float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype - ) for activation_granularity_cls in [PerTensor, PerRow]: for weight_granularity_cls in [PerTensor, PerRow]: activation_name = "t" if activation_granularity_cls is PerTensor else "r" @@ -526,22 +519,15 @@ def generate_float8dq_types(dtype: torch.dtype): weight_dtype=dtype, granularity=(activation_granularity_cls(), weight_granularity_cls()), ) - types[f"float8dq_{name}_a{activation_name}w{weight_name}_a8w8"] = partial( - float8_dynamic_activation_float8_weight, - activation_dtype=dtype, - weight_dtype=dtype, - granularity=(activation_granularity_cls(), weight_granularity_cls()), - ) return types def generate_fpx_quantization_types(bits: int): types = {} - for ebits in range(1, bits): + for ebits in range(0, bits): mbits = bits - ebits - 1 types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) - types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) non_sign_bits = bits - 1 default_ebits = (non_sign_bits + 1) // 2 @@ -550,20 +536,17 @@ def generate_fpx_quantization_types(bits: int): return types - # TODO(aryan): handle cuda capability and torch 2.2/2.3 + # TODO(aryan): handle torch 2.2/2.3 FLOATX_QUANTIZATION_TYPES = { # float8_e5m2 weight + bfloat16/float16 activation "float8": float8_weight_only, "float8_weight_only": float8_weight_only, "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - "float8_a16w8": 
float8_weight_only, "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - "float8_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), # float8_e4m3 weight + bfloat16/float16 activation "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), - "float8wo_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), # float8_e5m2 weight + float8 activation (dynamic) "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, "float8dq": float8_dynamic_activation_float8_weight, @@ -572,7 +555,6 @@ def generate_fpx_quantization_types(bits: int): activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2, ), - "float8_a8w8": float8_dynamic_activation_float8_weight, **generate_float8dq_types(torch.float8_e5m2), # float8_e4m3 weight + float8 activation (dynamic) "float8dq_e4m3": partial( @@ -609,7 +591,6 @@ def generate_uintx_quantization_types(bits: int): types = {} types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) - types[f"uint{bits}_a16w{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) return types UINTX_QUANTIZATION_DTYPES = { @@ -625,13 +606,41 @@ def generate_uintx_quantization_types(bits: int): **generate_uintx_quantization_types(8), } + SHORTHAND_QUANTIZATION_TYPES = { + "int_a16w4": int4_weight_only, + "int_a8w4": int8_dynamic_activation_int4_weight, + "int_a16w8": int8_weight_only, + "int_a8w8": int8_dynamic_activation_int8_weight, + "uint_a16w1": partial(uintx_weight_only, dtype=torch.uint1), + "uint_a16w2": partial(uintx_weight_only, dtype=torch.uint2), + "uint_a16w3": partial(uintx_weight_only, dtype=torch.uint3), + "uint_a16w4": partial(uintx_weight_only, dtype=torch.uint4), + "uint_a16w5": partial(uintx_weight_only, dtype=torch.uint5), + "uint_a16w6": partial(uintx_weight_only, dtype=torch.uint6), + "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7), + "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), + } + SHORTHAND_FLOAT_QUANTIZATION_TYPES = { + "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), + "float_a8w8": float8_dynamic_activation_float8_weight, + "float_a16w3": partial(fpx_weight_only, ebits=2, mbits=0), + "float_a16w4": partial(fpx_weight_only, ebits=2, mbits=1), + "float_a16w5": partial(fpx_weight_only, ebits=3, mbits=1), + "float_a16w6": partial(fpx_weight_only, ebits=3, mbits=2), + "float_a16w7": partial(fpx_weight_only, ebits=4, mbits=2), + "float_a16w8": partial(fpx_weight_only, ebits=5, mbits=2), + } + QUANTIZATION_TYPES = {} QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) + QUANTIZATION_TYPES.update(SHORTHAND_QUANTIZATION_TYPES) if cls._is_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) + QUANTIZATION_TYPES.update(SHORTHAND_FLOAT_QUANTIZATION_TYPES) return QUANTIZATION_TYPES else: From bc006f29bf229215515f918ffadadd73f8433de6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 25 Nov 2024 00:05:53 +0100 Subject: [PATCH 07/27] fix --- src/diffusers/quantizers/quantization_config.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 07f747819c70..beb89b9a70da 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -525,7 +525,7 @@ def generate_float8dq_types(dtype: torch.dtype): def generate_fpx_quantization_types(bits: int): types = {} - for ebits in range(0, bits): + for ebits in range(1, bits): mbits = bits - ebits - 1 types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) From 956f3bfddba6df1e39729218878b244fb0679d53 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 25 Nov 2024 02:19:31 +0100 Subject: [PATCH 08/27] fix --- .../quantizers/quantization_config.py | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index beb89b9a70da..493fcfb929be 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -503,22 +503,15 @@ def generate_float8dq_types(dtype: torch.dtype): name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" types = {} - for activation_granularity_cls in [PerTensor, PerRow]: - for weight_granularity_cls in [PerTensor, PerRow]: - activation_name = "t" if activation_granularity_cls is PerTensor else "r" - weight_name = "t" if weight_granularity_cls is PerTensor else "r" - # The a{activation_name}w{weight_name} is a made up name for convenience of testing things. - # It suffixes with for different granularities (activation granularity, weight granularity): - # - atwt: PerTensor(), PerTensor() - # - atwr: PerTensor(), PerRow() - # - arwt: PerRow(), PerTensor() - # - arwr: PerRow(), PerRow() - types[f"float8dq_{name}_a{activation_name}w{weight_name}"] = partial( - float8_dynamic_activation_float8_weight, - activation_dtype=dtype, - weight_dtype=dtype, - granularity=(activation_granularity_cls(), weight_granularity_cls()), - ) + for granularity_cls in [PerTensor, PerRow]: + # Note: Activation and Weights cannot have different granularities + granularity_name = "tensor" if granularity_cls is PerTensor else "row" + types[f"float8dq_{name}_{granularity_name}"] = partial( + float8_dynamic_activation_float8_weight, + activation_dtype=dtype, + weight_dtype=dtype, + granularity=(granularity_cls(), granularity_cls()), + ) return types @@ -550,12 +543,15 @@ def generate_fpx_quantization_types(bits: int): # float8_e5m2 weight + float8 activation (dynamic) "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, "float8dq": float8_dynamic_activation_float8_weight, - "float8dq_e5m2": partial( - float8_dynamic_activation_float8_weight, - activation_dtype=torch.float8_e5m2, - weight_dtype=torch.float8_e5m2, - ), - **generate_float8dq_types(torch.float8_e5m2), + # ===== Matrix multiplication is not supported in float8_e5m2 so the following error out. 
+ # However, changing activation_dtype=torch.float8_e4m3 might work here ===== + # "float8dq_e5m2": partial( + # float8_dynamic_activation_float8_weight, + # activation_dtype=torch.float8_e5m2, + # weight_dtype=torch.float8_e5m2, + # ), + # **generate_float8dq_types(torch.float8_e5m2), + # ===== ===== # float8_e4m3 weight + float8 activation (dynamic) "float8dq_e4m3": partial( float8_dynamic_activation_float8_weight, @@ -573,7 +569,8 @@ def generate_fpx_quantization_types(bits: int): **generate_fpx_quantization_types(5), **generate_fpx_quantization_types(6), **generate_fpx_quantization_types(7), - **generate_fpx_quantization_types(8), + # ===== Errors out with "torch.cat(): expected a non-empty list of Tensors" ===== + # **generate_fpx_quantization_types(8), } UINTX_TO_DTYPE = { @@ -584,7 +581,7 @@ def generate_fpx_quantization_types(bits: int): 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, - 8: torch.uint8, + # 8: torch.uint8, # uint8 quantization is not supported } def generate_uintx_quantization_types(bits: int): From cfdb94fd14f718ff6a0fe94b9c19eeac0dbecf17 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 25 Nov 2024 12:38:09 +0100 Subject: [PATCH 09/27] update --- docs/source/en/_toctree.yml | 2 + .../quantizers/quantization_config.py | 8 +- src/diffusers/utils/testing_utils.py | 13 + tests/quantization/torchao/README.md | 47 ++ tests/quantization/torchao/test_torchao.py | 444 ++++++++++++++++++ 5 files changed, 510 insertions(+), 4 deletions(-) create mode 100644 tests/quantization/torchao/README.md create mode 100644 tests/quantization/torchao/test_torchao.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2faabfec30ce..9f3ffa21f506 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -157,6 +157,8 @@ title: Getting Started - local: quantization/bitsandbytes title: bitsandbytes + - local: quantization/torchao + title: torchao title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 493fcfb929be..27c54b413fb8 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -452,7 +452,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non if len(unsupported_kwargs) > 0: raise ValueError( - f'The quantization method "{method}" does not supported the following keyword arguments: ' + f'The quantization method "{quant_type}" does not support the following keyword arguments: ' f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." 
) @@ -581,7 +581,7 @@ def generate_fpx_quantization_types(bits: int): 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, - # 8: torch.uint8, # uint8 quantization is not supported + # 8: torch.uint8, } def generate_uintx_quantization_types(bits: int): @@ -600,7 +600,7 @@ def generate_uintx_quantization_types(bits: int): **generate_uintx_quantization_types(5), **generate_uintx_quantization_types(6), **generate_uintx_quantization_types(7), - **generate_uintx_quantization_types(8), + # **generate_uintx_quantization_types(8), # uint8 quantization is not supported } SHORTHAND_QUANTIZATION_TYPES = { @@ -615,7 +615,7 @@ def generate_uintx_quantization_types(bits: int): "uint_a16w5": partial(uintx_weight_only, dtype=torch.uint5), "uint_a16w6": partial(uintx_weight_only, dtype=torch.uint6), "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7), - "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), + # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported } SHORTHAND_FLOAT_QUANTIZATION_TYPES = { "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index b3e381f7d3fb..b4d3415de50e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -39,6 +39,7 @@ is_timm_available, is_torch_available, is_torch_version, + is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -476,6 +477,18 @@ def decorator(test_case): return decorator +def require_torchao_version_greater(torchao_version): + def decorator(test_case): + correct_torchao_version = is_torchao_available() and version.parse( + version.parse(importlib.metadata.version("torchao")).base_version + ) > version.parse(torchao_version) + return unittest.skipUnless( + correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}." + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md new file mode 100644 index 000000000000..277eb1fbfb5d --- /dev/null +++ b/tests/quantization/torchao/README.md @@ -0,0 +1,47 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/). + +They were conducted on a single H100. Below is `nvidia-smi`: + +```bash ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | +| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| +| No running processes found | ++---------------------------------------------------------------------------------------+ +``` + +The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR. + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.32.0.dev0 +- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31 +- Running on Google Colab?: No +- Python version: 3.10.14 +- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.26.2 +- Transformers version: 4.46.3 +- Accelerate version: 1.1.1 +- PEFT version: not installed +- Bitsandbytes version: not installed +- Safetensors version: 0.4.5 +- xFormers version: not installed +``` diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py new file mode 100644 index 000000000000..9e51393b8cec --- /dev/null +++ b/tests/quantization/torchao/test_torchao.py @@ -0,0 +1,444 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest +from typing import List + +import numpy as np +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, + TorchAoConfig, +) +from diffusers.utils.testing_utils import ( + is_torch_available, + is_torchao_available, + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_torchao_version_greater, + torch_device, +) + + +if is_torch_available(): + import torch + +if is_torchao_available(): + from torchao.dtypes import AffineQuantizedTensor + from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + + +def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): + weight = qlayer.weight + test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) + test_module.assertEqual(weight.quant_min, 0) + test_module.assertEqual(weight.quant_max, 15) + test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + + +def check_forward(test_module, model, batch_size=1, context_size=1024): + # Test forward pass + with torch.no_grad(): + out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits + test_module.assertEqual(out.shape[0], batch_size) + test_module.assertEqual(out.shape[1], context_size) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +# @slow +class TorchAoConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = TorchAoConfig("int4_weight_only") + torchao_orig_config = quantization_config.to_dict() + + for key in torchao_orig_config: + self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) + + def test_post_init_check(self): + """ + Test kwargs validations in TorchAoConfig + """ + _ = TorchAoConfig("int4_weight_only") + with self.assertRaisesRegex(ValueError, "is not supported yet"): + _ = TorchAoConfig("uint8") + + with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): + _ = TorchAoConfig("int4_weight_only", group_size1=32) + + def test_repr(self): + """ + Check that there is no error in the repr + """ + quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) + repr(quantization_config) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +# @slow +class TorchAoTest(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + torch.manual_seed(0) + # TODO(aryan): push dummy model to hub + transformer = FluxTransformer2DModel.from_pretrained( + "./dummy-flux", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = 
AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 32, + "width": 32, + "num_inference_steps": 2, + "output_type": "np", + "generator": generator, + } + + return inputs + + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + output_slice = output[-1, -1, -3:, -3:].flatten() + + self.assertFalse(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("int4dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("int8wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("int8dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("uint4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("int_a8w8", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("uint_a16w7", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend( + [ + ("float8wo_e5m2", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("float8wo_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("fp4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("fp6", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ] + ) + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quantization_name) + self._test_quant_type(quantization_config, expected_slice) + + @unittest.skip("TODO(aryan): This test is not yet implemented.") + def test_int4wo_quant_bfloat16_conversion(self): + pass + # """ + # Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization + # """ + # quant_config = TorchAoConfig("int4_weight_only", group_size=32) + + # # Note: we quantize the bfloat16 model on the fly to int4 + # quantized_model = AutoModelForCausalLM.from_pretrained( + # self.model_name, + # torch_dtype=None, + # device_map=torch_device, + # quantization_config=quant_config, + # ) + # tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) + + # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + # 
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + @unittest.skip("TODO(aryan): This test is not yet implemented.") + def test_int4wo_quant_multi_gpu(self): + pass + # """ + # Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs + # set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + # """ + + # quant_config = TorchAoConfig("int4_weight_only", group_size=32) + # quantized_model = AutoModelForCausalLM.from_pretrained( + # self.model_name, + # torch_dtype=torch.bfloat16, + # device_map="auto", + # quantization_config=quant_config, + # ) + # tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @unittest.skip("TODO(aryan): This test is not yet implemented.") + def test_int4wo_offload(self): + pass + # """ + # Simple test that checks if the quantized model int4 wieght only is working properly with cpu/disk offload + # """ + + # device_map_offload = { + # "model.embed_tokens": 0, + # "model.layers.0": 0, + # "model.layers.1": 0, + # "model.layers.2": 0, + # "model.layers.3": 0, + # "model.layers.4": 0, + # "model.layers.5": 0, + # "model.layers.6": 0, + # "model.layers.7": 0, + # "model.layers.8": 0, + # "model.layers.9": 0, + # "model.layers.10": 0, + # "model.layers.11": 0, + # "model.layers.12": 0, + # "model.layers.13": 0, + # "model.layers.14": 0, + # "model.layers.15": 0, + # "model.layers.16": 0, + # "model.layers.17": 0, + # "model.layers.18": 0, + # "model.layers.19": "cpu", + # "model.layers.20": "cpu", + # "model.layers.21": "disk", + # "model.norm": 0, + # "model.rotary_emb": 0, + # "lm_head": 0, + # } + + # quant_config = TorchAoConfig("int4_weight_only", group_size=32) + + # quantized_model = AutoModelForCausalLM.from_pretrained( + # self.model_name, + # torch_dtype=torch.bfloat16, + # device_map=device_map_offload, + # quantization_config=quant_config, + # ) + # tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + # EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. 
What is the temperature outside" + + # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + + @unittest.skip("TODO(aryan): This test is not yet implemented.") + def test_int8_dynamic_activation_int8_weight_quant(self): + pass + # """ + # Simple LLM model testing int8_dynamic_activation_int8_weight + # """ + # quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") + + # # Note: we quantize the bfloat16 model on the fly to int4 + # quantized_model = AutoModelForCausalLM.from_pretrained( + # self.model_name, + # device_map=torch_device, + # quantization_config=quant_config, + # ) + # tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + # EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + + +# @require_torch_gpu +# @require_torchao +# class TorchAoSerializationTest(unittest.TestCase): +# input_text = "What are we having for dinner?" +# max_new_tokens = 10 +# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" +# # TODO: investigate why we don't have the same output as the original model for this test +# SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" +# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +# quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} +# device = "cuda:0" + +# # called only once for all test in this class +# @classmethod +# def setUpClass(cls): +# cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs) +# cls.quantized_model = AutoModelForCausalLM.from_pretrained( +# cls.model_name, +# torch_dtype=torch.bfloat16, +# device_map=cls.device, +# quantization_config=cls.quant_config, +# ) +# cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + +# def tearDown(self): +# gc.collect() +# torch.cuda.empty_cache() +# gc.collect() + +# def test_original_model_expected_output(self): +# input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) +# output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + +# self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT) + +# def check_serialization_expected_output(self, device, expected_output): +# """ +# Test if we can serialize and load/infer the model again on the same device +# """ +# with tempfile.TemporaryDirectory() as tmpdirname: +# self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False) +# loaded_quantized_model = AutoModelForCausalLM.from_pretrained( +# self.model_name, torch_dtype=torch.bfloat16, device_map=self.device +# ) +# input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) + +# output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) +# self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output) + +# def test_serialization_expected_output(self): +# self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT) + + +# class TorchAoSerializationW8A8Test(TorchAoSerializationTest): +# quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} +# ORIGINAL_EXPECTED_OUTPUT = "What are we having for 
dinner?\n\nJessica: (smiling)" +# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT +# device = "cuda:0" + + +# class TorchAoSerializationW8Test(TorchAoSerializationTest): +# quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} +# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" +# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT +# device = "cuda:0" + + +# class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): +# quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} +# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" +# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT +# device = "cpu" + +# def test_serialization_expected_output_cuda(self): +# """ +# Test if we can serialize on device (cpu) and load/infer the model on cuda +# """ +# new_device = "cuda:0" +# self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) + + +# class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): +# quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} +# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" +# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT +# device = "cpu" + +# def test_serialization_expected_output_cuda(self): +# """ +# Test if we can serialize on device (cpu) and load/infer the model on cuda +# """ +# new_device = "cuda:0" +# self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) + + +if __name__ == "__main__": + unittest.main() From 01b2b420697b895d42cc8e159305d50828609182 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 06:06:34 +0100 Subject: [PATCH 10/27] update tests --- tests/quantization/torchao/test_torchao.py | 264 +++++++-------------- 1 file changed, 89 insertions(+), 175 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 9e51393b8cec..8bb0bac9d119 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -14,11 +14,12 @@ # limitations under the License. 
import gc +import tempfile import unittest from typing import List import numpy as np -from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import ( AutoencoderKL, @@ -32,8 +33,8 @@ is_torchao_available, require_torch, require_torch_gpu, - require_torch_multi_gpu, require_torchao_version_greater, + slow, torch_device, ) @@ -46,14 +47,6 @@ from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType -def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): - weight = qlayer.weight - test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) - test_module.assertEqual(weight.quant_min, 0) - test_module.assertEqual(weight.quant_max, 15) - test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) - - def check_forward(test_module, model, batch_size=1, context_size=1024): # Test forward pass with torch.no_grad(): @@ -65,7 +58,7 @@ def check_forward(test_module, model, batch_size=1, context_size=1024): @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") -# @slow +@slow class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -99,7 +92,7 @@ def test_repr(self): @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") -# @slow +@slow class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() @@ -107,51 +100,18 @@ def tearDown(self): gc.collect() def get_dummy_components(self, quantization_config: TorchAoConfig): - torch.manual_seed(0) - # TODO(aryan): push dummy model to hub + model_id = "hf-internal-testing/tiny-flux-pipe" transformer = FluxTransformer2DModel.from_pretrained( - "./dummy-flux", + model_id, + subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, ) - clip_text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - hidden_act="gelu", - projection_dim=32, - ) - - torch.manual_seed(0) - text_encoder = CLIPTextModel(clip_text_encoder_config) - - torch.manual_seed(0) - text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - torch.manual_seed(0) - vae = AutoencoderKL( - sample_size=32, - in_channels=3, - out_channels=3, - block_out_channels=(4,), - layers_per_block=1, - latent_channels=1, - norm_num_groups=1, - use_quant_conv=False, - use_post_quant_conv=False, - shift_factor=0.0609, - scaling_factor=1.5035, - ) - + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -190,9 +150,10 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L output = pipe(**inputs)[0] output_slice = output[-1, -1, -3:, -3:].flatten() - self.assertFalse(np.allclose(output_slice, expected_slice, 
atol=1e-3, rtol=1e-3)) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): + # TODO(aryan): update these values from our CI QUANTIZATION_TYPES_TO_TEST = [ ("int4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), ("int4dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), @@ -220,129 +181,82 @@ def test_quantization(self): quantization_config = TorchAoConfig(quant_type=quantization_name) self._test_quant_type(quantization_config, expected_slice) - @unittest.skip("TODO(aryan): This test is not yet implemented.") def test_int4wo_quant_bfloat16_conversion(self): - pass - # """ - # Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization - # """ - # quant_config = TorchAoConfig("int4_weight_only", group_size=32) - - # # Note: we quantize the bfloat16 model on the fly to int4 - # quantized_model = AutoModelForCausalLM.from_pretrained( - # self.model_name, - # torch_dtype=None, - # device_map=torch_device, - # quantization_config=quant_config, - # ) - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - - # check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) - - # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - - @require_torch_multi_gpu - @unittest.skip("TODO(aryan): This test is not yet implemented.") - def test_int4wo_quant_multi_gpu(self): - pass - # """ - # Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs - # set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS - # """ - - # quant_config = TorchAoConfig("int4_weight_only", group_size=32) - # quantized_model = AutoModelForCausalLM.from_pretrained( - # self.model_name, - # torch_dtype=torch.bfloat16, - # device_map="auto", - # quantization_config=quant_config, - # ) - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - - # self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) - - # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - - @unittest.skip("TODO(aryan): This test is not yet implemented.") + """ + Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. 
+ """ + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + qlayer = quantized_model.transformer_blocks[0].attn.to_q + weight = qlayer.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertEqual(weight.quant_min, 0) + self.assertEqual(weight.quant_max, 15) + self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + def test_int4wo_offload(self): - pass - # """ - # Simple test that checks if the quantized model int4 wieght only is working properly with cpu/disk offload - # """ - - # device_map_offload = { - # "model.embed_tokens": 0, - # "model.layers.0": 0, - # "model.layers.1": 0, - # "model.layers.2": 0, - # "model.layers.3": 0, - # "model.layers.4": 0, - # "model.layers.5": 0, - # "model.layers.6": 0, - # "model.layers.7": 0, - # "model.layers.8": 0, - # "model.layers.9": 0, - # "model.layers.10": 0, - # "model.layers.11": 0, - # "model.layers.12": 0, - # "model.layers.13": 0, - # "model.layers.14": 0, - # "model.layers.15": 0, - # "model.layers.16": 0, - # "model.layers.17": 0, - # "model.layers.18": 0, - # "model.layers.19": "cpu", - # "model.layers.20": "cpu", - # "model.layers.21": "disk", - # "model.norm": 0, - # "model.rotary_emb": 0, - # "lm_head": 0, - # } - - # quant_config = TorchAoConfig("int4_weight_only", group_size=32) - - # quantized_model = AutoModelForCausalLM.from_pretrained( - # self.model_name, - # torch_dtype=torch.bfloat16, - # device_map=device_map_offload, - # quantization_config=quant_config, - # ) - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - - # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - # EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside" - - # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) - - @unittest.skip("TODO(aryan): This test is not yet implemented.") - def test_int8_dynamic_activation_int8_weight_quant(self): - pass - # """ - # Simple LLM model testing int8_dynamic_activation_int8_weight - # """ - # quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") - - # # Note: we quantize the bfloat16 model on the fly to int4 - # quantized_model = AutoModelForCausalLM.from_pretrained( - # self.model_name, - # device_map=torch_device, - # quantization_config=quant_config, - # ) - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - - # input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - # output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - # EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - # self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + """ + Test if the quantized model int4 weight-only is working properly with cpu/disk offload. 
+ """ + + device_map_offload = { + "time_text_embed": torch_device, + "context_embedder": torch_device, + "x_embedder": torch_device, + "transformer_blocks.0": "cpu", + "single_transformer_blocks.0": "disk", + "norm_out": torch_device, + "proj_out": "cpu", + } + + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to( + torch_device, dtype=torch.bfloat16 + ) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + torch_device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(torch_device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size) + + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map_offload, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + output = quantized_model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + img_ids=image_ids, + txt_ids=text_ids, + pooled_projections=pooled_prompt_embeds, + timestep=timestep, + ) + + output_slice = output.flatten()[-9:].detach().cpu().numpy() + # TODO(aryan): get slice from CI + expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) # @require_torch_gpu From b17cf35bb430106e1aeccfd294e565f15b931683 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 06:08:15 +0100 Subject: [PATCH 11/27] device map changes --- src/diffusers/models/modeling_utils.py | 4 ---- src/diffusers/quantizers/torchao/torchao_quantizer.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5c475c748e4a..3b5cbfbf968e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -671,10 +671,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: - if device_map is not None: - raise NotImplementedError( - "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future." 
- ) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 10522eaa390a..ca990dd6d982 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -35,6 +35,9 @@ import torch import torch.nn as nn +if is_torchao_available(): + from torchao.quantization import quantize_ + logger = logging.get_logger(__name__) @@ -131,7 +134,6 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": torch.uint5, torch.uint6, torch.uint7, - torch.uint8, ) if isinstance(target_dtype, supported_dtypes): return target_dtype @@ -181,8 +183,6 @@ def create_quantized_param( Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module. """ - from torchao.quantization import quantize_ - module, tensor_name = get_module_from_name(model, param_name) if self.pre_quantized: From 250ccf4d99ea9981f552e2cfa2ffc27e83d7b829 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 15:37:22 +0100 Subject: [PATCH 12/27] update; apply suggestions from review --- docs/source/en/quantization/torchao.md | 27 +- .../quantizers/quantization_config.py | 40 +- .../quantizers/torchao/torchao_quantizer.py | 19 +- tests/quantization/torchao/test_torchao.py | 341 +++++++++++------- 4 files changed, 263 insertions(+), 164 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index d10fb4b33cf2..335b7cfacb18 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -23,9 +23,32 @@ pip install -U torch torchao Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
-## Usage +```python +from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig + +model_id = "black-forest-labs/Flux.1-Dev" +dtype = torch.bfloat16 + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=dtype, +) +pipe = FluxPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=dtype, +) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] +image.save("output.png") +``` ## Resources -- [TorchAO Quantization API]() +- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) - [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 27c54b413fb8..a98a2fac1572 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -412,18 +412,20 @@ class TorchAoConfig(QuantizationConfigMixin): https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques Example: - - ```python - TODO(aryan): update - quantization_config = TorchAoConfig("int4_weight_only", group_size=32) - # int4_weight_only quant is only working with *torch.bfloat16* dtype right now - model = AutoModelForCausalLM.from_pretrained( - model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config - ) - ``` + ```python + from diffusers import FluxTransformer2DModel, TorchAoConfig + + quantization_config = TorchAoConfig("int8wo") + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + ``` """ - def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs) -> None: self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert @@ -481,7 +483,6 @@ def _get_torchao_quant_type_to_method(cls): INT4_QUANTIZATION_TYPES = { # int4 weight + bfloat16/float16 activation - "int4": int4_weight_only, "int4wo": int4_weight_only, "int4_weight_only": int4_weight_only, # int4 weight + int8 activation @@ -491,7 +492,6 @@ def _get_torchao_quant_type_to_method(cls): INT8_QUANTIZATION_TYPES = { # int8 weight + bfloat16/float16 activation - "int8": int8_weight_only, "int8wo": int8_weight_only, "int8_weight_only": int8_weight_only, # int8 weight + int8 activation @@ -532,17 +532,14 @@ def generate_fpx_quantization_types(bits: int): # TODO(aryan): handle torch 2.2/2.3 FLOATX_QUANTIZATION_TYPES = { # float8_e5m2 weight + bfloat16/float16 activation - "float8": float8_weight_only, - "float8_weight_only": float8_weight_only, "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), + "float8_weight_only": float8_weight_only, "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), # float8_e4m3 weight + bfloat16/float16 activation - "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), "float8wo_e4m3": 
partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), # float8_e5m2 weight + float8 activation (dynamic) - "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, "float8dq": float8_dynamic_activation_float8_weight, + "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, # ===== Matrix multiplication is not supported in float8_e5m2 so the following error out. # However, changing activation_dtype=torch.float8_e4m3 might work here ===== # "float8dq_e5m2": partial( @@ -581,7 +578,7 @@ def generate_fpx_quantization_types(bits: int): 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, - # 8: torch.uint8, + # 8: torch.uint8, # uint8 quantization is not supported } def generate_uintx_quantization_types(bits: int): @@ -626,7 +623,7 @@ def generate_uintx_quantization_types(bits: int): "float_a16w5": partial(fpx_weight_only, ebits=3, mbits=1), "float_a16w6": partial(fpx_weight_only, ebits=3, mbits=2), "float_a16w7": partial(fpx_weight_only, ebits=4, mbits=2), - "float_a16w8": partial(fpx_weight_only, ebits=5, mbits=2), + "float_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), } QUANTIZATION_TYPES = {} @@ -647,6 +644,11 @@ def generate_uintx_quantization_types(bits: int): @staticmethod def _is_cuda_capability_atleast_8_9() -> bool: + if not torch.cuda.is_available(): + if torch.mps.is_available(): + return False + raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.") + major, minor = torch.cuda.get_device_capability() if major == 8: return minor >= 9 diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index ca990dd6d982..f4d56c00ca26 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -96,7 +96,7 @@ def validate_environment(self, *args, **kwargs): if weights_only: torch_version = version.parse(importlib.metadata.version("torch")) if torch_version < version.parse("2.5.0"): - # TODO(aryan): TorchAO is compatible with Pytorch 2.2 for certain quantization types. Try to see if we can support it + # TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future raise RuntimeError( f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}." ) @@ -112,7 +112,7 @@ def update_torch_dtype(self, torch_dtype): ) if torch_dtype is None: - # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op + # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op logger.warning( "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " @@ -122,8 +122,11 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype - def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + def adjust_target_dtype(self, target_dtype: torch.dtype) -> torch.dtype: supported_dtypes = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. 
torch.int8, torch.float8_e4m3fn, torch.float8_e5m2, @@ -138,6 +141,9 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": if isinstance(target_dtype, supported_dtypes): return target_dtype + # We need one of the supported dtypes to be selected in order for accelerate to determine + # the total size of modules/parameters for auto device placement. This method will not be + # called when device_map is not "auto". raise ValueError( f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype " f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the " @@ -145,8 +151,6 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": ) def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: - # TODO(aryan): Not sure what to do here @sayakpaul, @DN6 - # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128 max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory @@ -159,14 +163,14 @@ def check_if_quantized_param( **kwargs, ) -> bool: param_device = kwargs.pop("param_device", None) - # check if the param_name is not in self.modules_to_not_convert + # Check if the param_name is not in self.modules_to_not_convert if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): return False elif param_device == "cpu" and self.offload: # We don't quantize weights that we offload return False else: - # we only quantize the weight of nn.Linear + # We only quantize the weight of nn.Linear module, tensor_name = get_module_from_name(model, param_name) return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") @@ -249,5 +253,4 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self): - # TODO(aryan): needs testing return self.quantization_config.quant_type.startswith("int8") diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 8bb0bac9d119..fdafc7d2d74a 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -28,6 +28,7 @@ FluxTransformer2DModel, TorchAoConfig, ) +from diffusers.models.attention_processor import Attention from diffusers.utils.testing_utils import ( is_torch_available, is_torchao_available, @@ -41,20 +42,36 @@ if is_torch_available(): import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + if is_torchao_available(): from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType -def check_forward(test_module, 
model, batch_size=1, context_size=1024): - # Test forward pass - with torch.no_grad(): - out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits - test_module.assertEqual(out.shape[0], batch_size) - test_module.assertEqual(out.shape[1], context_size) - - @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") @@ -97,7 +114,6 @@ class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() torch.cuda.empty_cache() - gc.collect() def get_dummy_components(self, quantization_config: TorchAoConfig): model_id = "hf-internal-testing/tiny-flux-pipe" @@ -141,6 +157,32 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0): return inputs + def get_dummy_tensor_inputs(self, device=None): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): components = self.get_dummy_components(quantization_config) pipe = FluxPipeline(**components) @@ -193,14 +235,13 @@ def test_int4wo_quant_bfloat16_conversion(self): torch_dtype=torch.bfloat16, ) - qlayer = quantized_model.transformer_blocks[0].attn.to_q - weight = qlayer.weight + weight = quantized_model.transformer_blocks[0].ff.net[2].weight self.assertTrue(isinstance(weight, AffineQuantizedTensor)) self.assertEqual(weight.quant_min, 0) self.assertEqual(weight.quant_max, 15) self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) - def test_int4wo_offload(self): + def test_offload(self): """ Test if the quantized model int4 weight-only is working properly with cpu/disk offload. 
""" @@ -215,23 +256,7 @@ def test_int4wo_offload(self): "proj_out": "cpu", } - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - height = width = 4 - sequence_length = 48 - embedding_dim = 32 - - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to( - torch_device, dtype=torch.bfloat16 - ) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( - torch_device, dtype=torch.bfloat16 - ) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device, dtype=torch.bfloat16) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device, dtype=torch.bfloat16) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device, dtype=torch.bfloat16) - timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size) + inputs = self.get_dummy_tensor_inputs(torch_device) with tempfile.TemporaryDirectory() as offload_folder: quantization_config = TorchAoConfig("int4_weight_only", group_size=64) @@ -244,115 +269,161 @@ def test_int4wo_offload(self): offload_folder=offload_folder, ) - output = quantized_model( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - img_ids=image_ids, - txt_ids=text_ids, - pooled_projections=pooled_prompt_embeds, - timestep=timestep, - ) + output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().cpu().numpy() + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() # TODO(aryan): get slice from CI expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + def test_modules_to_not_convert(self): + quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2] + self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) + self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) + self.assertEquals(unquantized_layer.weight.dtype, torch.bfloat16) + + quantized_layer = quantized_model.proj_out + self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) + self.assertEquals(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) + + def test_training(self): + quantization_config = TorchAoConfig("int8_weight_only") + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + 
self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +@slow +class TorchAoSerializationTest(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + quant_method, quant_method_kwargs = None, None + device = "cuda" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_model(self, device=None): + quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs) + quantized_model = FluxTransformer2DModel.from_pretrained( + self.model_name, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + return quantized_model.to(device) + + def get_dummy_tensor_inputs(self, device=None): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def test_original_model_expected_slice(self): + quantized_model = self.get_dummy_model(torch_device) + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3)) + + def check_serialization_expected_slice(self, expected_slice): + quantized_model = self.get_dummy_model(self.device) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantized_model.save_pretrained(tmp_dir, safe_serialization=False) + loaded_quantized_model = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False + ) + + inputs = self.get_dummy_tensor_inputs(torch_device) + output = loaded_quantized_model(**inputs)[0] + + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_serialization_expected_slice(self): + self.check_serialization_expected_slice(self.serialized_expected_slice) + + +class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + serialized_expected_slice = expected_slice + device = "cuda" + + +class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + serialized_expected_slice = expected_slice + device = "cuda" + + +class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): + 
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + serialized_expected_slice = expected_slice + device = "cpu" + -# @require_torch_gpu -# @require_torchao -# class TorchAoSerializationTest(unittest.TestCase): -# input_text = "What are we having for dinner?" -# max_new_tokens = 10 -# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" -# # TODO: investigate why we don't have the same output as the original model for this test -# SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" -# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -# quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} -# device = "cuda:0" - -# # called only once for all test in this class -# @classmethod -# def setUpClass(cls): -# cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs) -# cls.quantized_model = AutoModelForCausalLM.from_pretrained( -# cls.model_name, -# torch_dtype=torch.bfloat16, -# device_map=cls.device, -# quantization_config=cls.quant_config, -# ) -# cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) - -# def tearDown(self): -# gc.collect() -# torch.cuda.empty_cache() -# gc.collect() - -# def test_original_model_expected_output(self): -# input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) -# output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - -# self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT) - -# def check_serialization_expected_output(self, device, expected_output): -# """ -# Test if we can serialize and load/infer the model again on the same device -# """ -# with tempfile.TemporaryDirectory() as tmpdirname: -# self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False) -# loaded_quantized_model = AutoModelForCausalLM.from_pretrained( -# self.model_name, torch_dtype=torch.bfloat16, device_map=self.device -# ) -# input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) - -# output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) -# self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output) - -# def test_serialization_expected_output(self): -# self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT) - - -# class TorchAoSerializationW8A8Test(TorchAoSerializationTest): -# quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} -# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" -# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT -# device = "cuda:0" - - -# class TorchAoSerializationW8Test(TorchAoSerializationTest): -# quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} -# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" -# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT -# device = "cuda:0" - - -# class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): -# quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} -# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" -# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT -# device = "cpu" - -# def test_serialization_expected_output_cuda(self): -# """ -# Test if we can serialize on device (cpu) and load/infer the 
model on cuda -# """ -# new_device = "cuda:0" -# self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) - - -# class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): -# quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} -# ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" -# SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT -# device = "cpu" - -# def test_serialization_expected_output_cuda(self): -# """ -# Test if we can serialize on device (cpu) and load/infer the model on cuda -# """ -# new_device = "cuda:0" -# self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) - - -if __name__ == "__main__": - unittest.main() +class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + serialized_expected_slice = expected_slice + device = "cpu" From 50946a9c7f94e5502a421a22e48fb26f9d57041d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 15:39:28 +0100 Subject: [PATCH 13/27] fix --- src/diffusers/quantizers/torchao/torchao_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index f4d56c00ca26..0e49d3771a77 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -122,7 +122,7 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype - def adjust_target_dtype(self, target_dtype: torch.dtype) -> torch.dtype: + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": supported_dtypes = ( # At the moment, only int8 is supported for integer quantization dtypes. 
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future From 8f09bdf14c2b9fddc70150aed101162118c22f89 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 15:41:29 +0100 Subject: [PATCH 14/27] remove slow marker --- tests/quantization/torchao/test_torchao.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index fdafc7d2d74a..2f29f255af84 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -35,7 +35,6 @@ require_torch, require_torch_gpu, require_torchao_version_greater, - slow, torch_device, ) @@ -75,7 +74,6 @@ def forward(self, input, *args, **kwargs): @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") -@slow class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -109,7 +107,6 @@ def test_repr(self): @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") -@slow class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() @@ -329,7 +326,6 @@ def test_training(self): @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") -@slow class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" quant_method, quant_method_kwargs = None, None From 7c79b8e8786cc7a6a6619f4958f5fe840f37c3a4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 15:53:25 +0100 Subject: [PATCH 15/27] remove pytest deprecation warnings --- tests/quantization/torchao/test_torchao.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 2f29f255af84..fb4f4fa26b87 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -285,11 +285,11 @@ def test_modules_to_not_convert(self): unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2] self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) - self.assertEquals(unquantized_layer.weight.dtype, torch.bfloat16) + self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) quantized_layer = quantized_model.proj_out self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) - self.assertEquals(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) + self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) def test_training(self): quantization_config = TorchAoConfig("int8_weight_only") From 747bd7d9bf8b5e386efb92f2ae6031bb93d6ab98 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 09:45:38 +0100 Subject: [PATCH 16/27] apply review suggestions --- docs/source/en/quantization/torchao.md | 2 + .../quantizers/quantization_config.py | 101 ++++++++++-------- .../quantizers/torchao/torchao_quantizer.py | 38 ++++--- 3 files changed, 79 insertions(+), 62 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 335b7cfacb18..78d69569a530 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -48,6 +48,8 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] image.save("output.png") ``` +Additionally, TorchAO supports an automatic quantization API exposed with 
[`autoquant`](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. This can directly be used with the underlying modeling components at the moment, but Diffusers will also expose an autoquant configuration option in the future. + ## Resources - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index a98a2fac1572..7022e28b8db6 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -436,14 +436,14 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non else: self.quant_type_kwargs = kwargs - _STR_TO_METHOD = self._get_torchao_quant_type_to_method() - if self.quant_type not in _STR_TO_METHOD.keys(): + TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() + if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): raise ValueError( f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the " f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." ) - method = _STR_TO_METHOD[self.quant_type] + method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] signature = inspect.signature(method) all_kwargs = { param.name @@ -481,24 +481,6 @@ def _get_torchao_quant_type_to_method(cls): # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers from torchao.quantization.observer import PerRow, PerTensor - INT4_QUANTIZATION_TYPES = { - # int4 weight + bfloat16/float16 activation - "int4wo": int4_weight_only, - "int4_weight_only": int4_weight_only, - # int4 weight + int8 activation - "int4dq": int8_dynamic_activation_int4_weight, - "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, - } - - INT8_QUANTIZATION_TYPES = { - # int8 weight + bfloat16/float16 activation - "int8wo": int8_weight_only, - "int8_weight_only": int8_weight_only, - # int8 weight + int8 activation - "int8dq": int8_dynamic_activation_int8_weight, - "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, - } - def generate_float8dq_types(dtype: torch.dtype): name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" types = {} @@ -529,6 +511,41 @@ def generate_fpx_quantization_types(bits: int): return types + def generate_uintx_quantization_types(bits: int): + UINTX_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + # 8: torch.uint8, # uint8 quantization is not supported + } + + types = {} + types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) + types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) + return types + + INT4_QUANTIZATION_TYPES = { + # int4 weight + bfloat16/float16 activation + "int4wo": int4_weight_only, + "int4_weight_only": int4_weight_only, + # int4 weight + int8 activation + "int4dq": int8_dynamic_activation_int4_weight, + "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, + } + + INT8_QUANTIZATION_TYPES = { + # int8 weight + bfloat16/float16 activation + "int8wo": int8_weight_only, + "int8_weight_only": int8_weight_only, + # int8 weight + 
int8 activation + "int8dq": int8_dynamic_activation_int8_weight, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + } + # TODO(aryan): handle torch 2.2/2.3 FLOATX_QUANTIZATION_TYPES = { # float8_e5m2 weight + bfloat16/float16 activation @@ -540,7 +557,7 @@ def generate_fpx_quantization_types(bits: int): # float8_e5m2 weight + float8 activation (dynamic) "float8dq": float8_dynamic_activation_float8_weight, "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, - # ===== Matrix multiplication is not supported in float8_e5m2 so the following error out. + # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out. # However, changing activation_dtype=torch.float8_e4m3 might work here ===== # "float8dq_e5m2": partial( # float8_dynamic_activation_float8_weight, @@ -566,27 +583,8 @@ def generate_fpx_quantization_types(bits: int): **generate_fpx_quantization_types(5), **generate_fpx_quantization_types(6), **generate_fpx_quantization_types(7), - # ===== Errors out with "torch.cat(): expected a non-empty list of Tensors" ===== - # **generate_fpx_quantization_types(8), - } - - UINTX_TO_DTYPE = { - 1: torch.uint1, - 2: torch.uint2, - 3: torch.uint3, - 4: torch.uint4, - 5: torch.uint5, - 6: torch.uint6, - 7: torch.uint7, - # 8: torch.uint8, # uint8 quantization is not supported } - def generate_uintx_quantization_types(bits: int): - types = {} - types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) - types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) - return types - UINTX_QUANTIZATION_DTYPES = { "uintx": uintx_weight_only, "uintx_weight_only": uintx_weight_only, @@ -614,6 +612,7 @@ def generate_uintx_quantization_types(bits: int): "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7), # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported } + SHORTHAND_FLOAT_QUANTIZATION_TYPES = { "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), @@ -645,8 +644,6 @@ def generate_uintx_quantization_types(bits: int): @staticmethod def _is_cuda_capability_atleast_8_9() -> bool: if not torch.cuda.is_available(): - if torch.mps.is_available(): - return False raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.") major, minor = torch.cuda.get_device_capability() @@ -655,9 +652,23 @@ def _is_cuda_capability_atleast_8_9() -> bool: return major >= 9 def get_apply_tensor_subclass(self): - _STR_TO_METHOD = self._get_torchao_quant_type_to_method() - return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) + TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() + return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs) def __repr__(self): + r""" + Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`: + + ``` + TorchAoConfig { + "modules_to_not_convert": null, + "quant_method": "torchao", + "quant_type": "uint_a16w4", + "quant_type_kwargs": { + "group_size": 32 + } + } + ``` + """ config_dict = self.to_dict() return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 0e49d3771a77..82791550806c 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ 
b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -41,6 +41,22 @@ logger = logging.get_logger(__name__) +SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, +) + def _quantization_type(weight): from torchao.dtypes import AffineQuantizedTensor @@ -123,22 +139,7 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - supported_dtypes = ( - # At the moment, only int8 is supported for integer quantization dtypes. - # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future - # to support more quantization methods, such as intx_weight_only. - torch.int8, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - ) - if isinstance(target_dtype, supported_dtypes): + if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): return target_dtype # We need one of the supported dtypes to be selected in order for accelerate to determine @@ -146,7 +147,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": # called when device_map is not "auto". raise ValueError( f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype " - f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the " + f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the " f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." 
) @@ -190,10 +191,13 @@ def create_quantized_param( module, tensor_name = get_module_from_name(model, param_name) if self.pre_quantized: + # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info + # about AffineQuantizedTensor module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) if isinstance(module, nn.Linear): module.extra_repr = types.MethodType(_linear_extra_repr, module) else: + # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) quantize_(module, self.quantization_config.get_apply_tensor_subclass()) From 25d3cf856a6dc2576de10d85337bb39d4de30dd2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 10:17:35 +0100 Subject: [PATCH 17/27] add torch compile test --- tests/quantization/torchao/test_torchao.py | 28 +++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index fb4f4fa26b87..fd290aa092f9 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -154,7 +154,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0): return inputs - def get_dummy_tensor_inputs(self, device=None): + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): batch_size = 1 num_latent_channels = 4 num_image_channels = 3 @@ -162,13 +162,23 @@ def get_dummy_tensor_inputs(self, device=None): sequence_length = 48 embedding_dim = 32 + torch.manual_seed(seed) hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( device, dtype=torch.bfloat16 ) + + torch.manual_seed(seed) pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) return { @@ -322,6 +332,22 @@ def test_training(self): self.assertTrue(module.adapter[1].weight.grad is not None) self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + def test_torch_compile(self): + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] + + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] + + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + @require_torch @require_torch_gpu From 10deb1699e3b87e8c687f83f354be6c5c00dd6a3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 13:45:24 +0100 Subject: [PATCH 18/27] add more tests; add expected slices --- tests/quantization/torchao/README.md | 8 +- 
tests/quantization/torchao/test_torchao.py | 224 ++++++++++++++++++--- 2 files changed, 200 insertions(+), 32 deletions(-) diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md index 277eb1fbfb5d..fadc529e12fc 100644 --- a/tests/quantization/torchao/README.md +++ b/tests/quantization/torchao/README.md @@ -1,6 +1,6 @@ The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/). -They were conducted on a single H100. Below is `nvidia-smi`: +The benchmarks were run on a single H100. Below is `nvidia-smi`: ```bash +---------------------------------------------------------------------------------------+ @@ -26,6 +26,12 @@ They were conducted on a single H100. Below is `nvidia-smi`: The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR. +The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent: + +```bash +HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests +``` + `diffusers-cli`: ```bash diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index fd290aa092f9..274330a8bf6b 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -30,15 +30,20 @@ ) from diffusers.models.attention_processor import Attention from diffusers.utils.testing_utils import ( + enable_full_determinism, is_torch_available, is_torchao_available, require_torch, require_torch_gpu, require_torchao_version_greater, + slow, torch_device, ) +enable_full_determinism() + + if is_torch_available(): import torch import torch.nn as nn @@ -101,9 +106,21 @@ def test_repr(self): Check that there is no error in the repr """ quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) - repr(quantization_config) - - + expected_repr = """TorchAoConfig { + "modules_to_not_convert": [ + "conv" + ], + "quant_method": "torchao", + "quant_type": "int4_weight_only", + "quant_type_kwargs": { + "group_size": 8 + } + }""".replace(" ", "").replace("\n", "") + quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") + self.assertEqual(quantization_repr, expected_repr) + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") @@ -202,32 +219,44 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): - # TODO(aryan): update these values from our CI + # fmt: off QUANTIZATION_TYPES_TO_TEST = [ - ("int4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("int4dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("int8wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("int8dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("uint4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("int_a8w8", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("uint_a16w7", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", 
np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ] if TorchAoConfig._is_cuda_capability_atleast_8_9(): - QUANTIZATION_TYPES_TO_TEST.extend( - [ - ("float8wo_e5m2", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("float8wo_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("fp4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ("fp6", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - ] - ) + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quantization_config = TorchAoConfig(quant_type=quantization_name) + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint_a16w7"]: + # The dummy flux model that we use requires us to impose some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) self._test_quant_type(quantization_config, expected_slice) def test_int4wo_quant_bfloat16_conversion(self): @@ -277,10 +306,9 @@ def test_offload(self): ) output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - # TODO(aryan): get slice from CI - expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + + expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_modules_to_not_convert(self): @@ -333,6 +361,7 @@ def test_training(self): self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) def test_torch_compile(self): + r"""Test that verifies if torch.compile works with torchao quantization.""" quantization_config = TorchAoConfig("int8_weight_only") components = self.get_dummy_components(quantization_config) pipe = FluxPipeline(**components) @@ -348,7 +377,54 @@ def test_torch_compile(self): # Note: Seems to require higher tolerance self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + @staticmethod + def _get_memory_footprint(module): + 
quantized_param_memory = 0.0 + unquantized_param_memory = 0.0 + + for param in module.parameters(): + if param.__class__.__name__ == "AffineQuantizedTensor": + data, scale, zero_point = param.layout_tensor.get_plain() + quantized_param_memory += data.numel() + data.element_size() + quantized_param_memory += scale.numel() + scale.element_size() + quantized_param_memory += zero_point.numel() + zero_point.element_size() + else: + unquantized_param_memory += param.data.numel() * param.data.element_size() + + total_memory = quantized_param_memory + unquantized_param_memory + return total_memory, quantized_param_memory, unquantized_param_memory + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] + transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] + transformer_bf16 = self.get_dummy_components(None)["transformer"] + + total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo) + total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint( + transformer_int4wo_gs32 + ) + total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo) + total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16) + self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16) + # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32) + # int4 with default group size quantized very few linear layers compared to a smaller group size of 32 + self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32) + # int8 quantizes more layers compare to int4 with default group size + self.assertTrue(quantized_int8wo < quantized_int4wo) + + def test_wrong_config(self): + with self.assertRaises(ValueError): + self.get_dummy_components(TorchAoConfig("int42")) + + +# This class is not to be run as a test by itself. 
See the tests that follow this class @require_torch @require_torch_gpu @require_torchao_version_greater("0.6.0") @@ -371,7 +447,7 @@ def get_dummy_model(self, device=None): ) return quantized_model.to(device) - def get_dummy_tensor_inputs(self, device=None): + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): batch_size = 1 num_latent_channels = 4 num_image_channels = 3 @@ -379,6 +455,7 @@ def get_dummy_tensor_inputs(self, device=None): sequence_length = 48 embedding_dim = 32 + torch.manual_seed(seed) hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( device, dtype=torch.bfloat16 @@ -425,27 +502,112 @@ def test_serialization_expected_slice(self): class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) serialized_expected_slice = expected_slice device = "cuda" class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) serialized_expected_slice = expected_slice device = "cuda" class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} - expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) serialized_expected_slice = expected_slice device = "cpu" class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): quant_method, quant_method_kwargs = "int8_weight_only", {} - expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]) + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) serialized_expected_slice = expected_slice device = "cpu" + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +@slow +class SlowTorchAoTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + model_id = "black-forest-labs/FLUX.1-dev" + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if 
str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def _test_quant_type(self, quantization_config, expected_slice): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components).to(dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), + ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), + ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) + self._test_quant_type(quantization_config, expected_slice) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() From f3771a8036d4996f7e58514bff845ae6f0d50370 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 13:50:02 +0100 Subject: [PATCH 19/27] fix --- .../quantizers/torchao/torchao_quantizer.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 82791550806c..b98bdd9c18f9 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -35,28 +35,28 @@ import torch import torch.nn as nn + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. 
+ torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) + if is_torchao_available(): from torchao.quantization import quantize_ logger = logging.get_logger(__name__) -SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( - # At the moment, only int8 is supported for integer quantization dtypes. - # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future - # to support more quantization methods, such as intx_weight_only. - torch.int8, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, -) - def _quantization_type(weight): from torchao.dtypes import AffineQuantizedTensor From de97a511e669fc5be1d25cc4b70c4575fad04ff5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 14:11:08 +0100 Subject: [PATCH 20/27] improve test check --- docs/source/en/quantization/torchao.md | 4 ++-- tests/quantization/torchao/test_torchao.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 78d69569a530..d8f57a00b280 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # torchao -[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). Before you begin, make sure you have Pytorch version 2.5, or above, and TorchAO installed: @@ -21,7 +21,7 @@ pip install -U torch torchao ## Usage -Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. +Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. Loading pre-quantized models is supported as well! This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. 
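As a rough sketch of the pre-quantized loading path, the snippet below mirrors the serialization tests added in this PR; the local directory is a hypothetical placeholder, and the `safe_serialization=False` / `use_safetensors=False` flags follow what those tests use for torchao tensor subclasses:

```python
# Hedged sketch: quantize once, save the quantized weights, then reload them
# without re-quantizing. "/path/to/flux-int8wo" is a made-up location.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Save the already-quantized weights; the serialization tests in this PR use
# safe_serialization=False for torchao tensors, so the same flags are used here.
transformer.save_pretrained("/path/to/flux-int8wo", safe_serialization=False)
reloaded_transformer = FluxTransformer2DModel.from_pretrained(
    "/path/to/flux-int8wo", torch_dtype=torch.bfloat16, use_safetensors=False
)
```

The end-to-end example below shows the more common path of quantizing directly at load time.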
```python from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 274330a8bf6b..eb09119edf61 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -74,6 +74,7 @@ def forward(self, input, *args, **kwargs): if is_torchao_available(): from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor @require_torch @@ -494,6 +495,11 @@ def check_serialization_expected_slice(self, expected_slice): output = loaded_quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue( + isinstance( + loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor) + ) + ) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_serialization_expected_slice(self): From 101d10cfe2c7e7ad2b976ee707384160940b4cc4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 14:22:03 +0100 Subject: [PATCH 21/27] update docs --- docs/source/en/quantization/torchao.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index d8f57a00b280..57340c8d07e6 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -48,6 +48,15 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] image.save("output.png") ``` +TorchAO offers seamless compatibility with `torch.compile`, setting it apart from other quantization methods. This ensures one to achieve remarkable speedups with ease. + +```python +# In the above code, add the following after initializing the transformer +transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) +``` + +For speed/memory benchmarks on Flux/CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). + Additionally, TorchAO supports an automatic quantization API exposed with [`autoquant`](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. This can directly be used with the underlying modeling components at the moment, but Diffusers will also expose an autoquant configuration option in the future. 
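Since the configuration option is not exposed in Diffusers yet, using autoquant today means calling torchao directly on a modeling component. A minimal, hedged sketch, assuming the `torchao.autoquant` entry point described in the torchao README (exact arguments may differ between torchao versions):

```python
# Hedged sketch: let torchao pick a quantization strategy for the transformer.
import torch
import torchao
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# autoquant decides per-layer strategies based on the shapes it actually sees,
# so the first inference after wrapping acts as the calibration run.
pipe.transformer = torchao.autoquant(torch.compile(pipe.transformer, mode="max-autotune"))
image = pipe("an astronaut riding a horse in space", num_inference_steps=4).images[0]
```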
## Resources From edd98dba1aeb78d8a73f8642b34552585d96c5cd Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 14:22:12 +0100 Subject: [PATCH 22/27] bnb device map check --- src/diffusers/models/modeling_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0f76d3dbdf23..ed4f71e7529d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -671,6 +671,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: + is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" + if is_bnb_quantization_method and device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + ) + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) From cc70887e54f6fbb948685efaf09d0a6e3577162d Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 20:31:59 +0100 Subject: [PATCH 23/27] update docs --- docs/source/en/quantization/torchao.md | 41 +++++++++++ .../quantizers/quantization_config.py | 68 +++++++++++-------- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 57340c8d07e6..6730b5971baf 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -59,6 +59,47 @@ For speed/memory benchmarks on Flux/CogVideoX, please refer to the table [here]( Additionally, TorchAO supports an automatic quantization API exposed with [`autoquant`](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. This can directly be used with the underlying modeling components at the moment, but Diffusers will also expose an autoquant configuration option in the future. +The `TorchAoConfig` class accepts three parameters: +- `quant_type`: A string value mentioning one of the quantization types below. +- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`. +- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`. + +## Supported quantization types + +Broadly, quantization in the follow data types is supported: `int8`, `float3-float8` and `uint1-uint7`. Among these types, there exists weight-only quantization techniques and weight + dynamic-activation quantization techniques. + +Weight-only quantization refers to storing the model weights in a specific low-bit data type but performing computation in a higher precision data type, like `bfloat16`. This lowers the memory requirements from model weights, but retains the memory peaks for activation computation. 
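To make the weight-only case concrete (and to tie it back to `modules_to_not_convert` above), the following sketch mirrors the checks used in the tests added by this PR; excluding `x_embedder` is only an illustration, not a recommendation:

```python
# Weight-only int8: weights are stored as quantized tensors while computation
# still runs in bfloat16. "x_embedder" is skipped purely for illustration.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.dtypes import AffineQuantizedTensor

quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["x_embedder"])
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

print(isinstance(transformer.proj_out.weight, AffineQuantizedTensor))  # True: quantized storage
print(transformer.x_embedder.weight.dtype)  # torch.bfloat16: left in the original precision
```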
+ +Dynamic Activation quantization refers to storing the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly before settling for your favourite quantization method. + +The quantization methods supported are as follows: + +- **Integer quantization:** + - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` + - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` + - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8` + +- **Floating point 8-bit quantization:** + - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` + - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq` + - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, `float_a8w8`, `float_a16w8` + +- **Floating point X-bit quantization:** + - Full function names: `fpx_weight_only` + - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must be satisfied for a given shorthand notation. + - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`, `float_a16w6`, `float_a16w7`, `float_a16w8` + +- **Unsigned Integer quantization:** + - Full function names: `uintx_weight_only` + - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` + - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, `uint_a16w5`, `uint_a16w6`, `uint_a16w7` + +The "Documentation shorthands/Common speak" representation is simply the underlying storage dtype with the number of bits for storing activations and weights respectively. + +Note that some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows the usage of the quantization methods as specified in the TorchAO docs as-is, while also making it convenient to use easy to remember shorthand notations. + +It is recommended to check out the official TorchAO Documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. + ## Resources - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 7022e28b8db6..bc5f42f0b238 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -400,9 +400,35 @@ class TorchAoConfig(QuantizationConfigMixin): Args: quant_type (`str`): - The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and - `int8_dynamic_activation_int8_weight`. 
- modules_to_not_convert (`list`, *optional*, default to `None`): + The type of quantization we want to use, currently supporting: + - **Integer quantization:** + - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, + `int8_weight_only`, `int8_dynamic_activation_int8_weight` + - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` + - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8` + + - **Floating point 8-bit quantization:** + - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, + `float8_static_activation_float8_weight` + - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, + `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq` + - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, `float_a8w8`, + `float_a16w8` + + - **Floating point X-bit quantization:** + - Full function names: `fpx_weight_only` + - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number + of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must + be satisfied for a given shorthand notation. + - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`, + `float_a16w6`, `float_a16w7`, `float_a16w8` + + - **Unsigned Integer quantization:** + - Full function names: `uintx_weight_only` + - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` + - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, + `uint_a16w5`, `uint_a16w6`, `uint_a16w7` + modules_to_not_convert (`List[str]`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. 
kwargs (`Dict[str, Any]`, *optional*): @@ -425,7 +451,7 @@ class TorchAoConfig(QuantizationConfigMixin): ``` """ - def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs) -> None: + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None: self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert @@ -511,23 +537,6 @@ def generate_fpx_quantization_types(bits: int): return types - def generate_uintx_quantization_types(bits: int): - UINTX_TO_DTYPE = { - 1: torch.uint1, - 2: torch.uint2, - 3: torch.uint3, - 4: torch.uint4, - 5: torch.uint5, - 6: torch.uint6, - 7: torch.uint7, - # 8: torch.uint8, # uint8 quantization is not supported - } - - types = {} - types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) - types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) - return types - INT4_QUANTIZATION_TYPES = { # int4 weight + bfloat16/float16 activation "int4wo": int4_weight_only, @@ -586,16 +595,15 @@ def generate_uintx_quantization_types(bits: int): } UINTX_QUANTIZATION_DTYPES = { - "uintx": uintx_weight_only, "uintx_weight_only": uintx_weight_only, - **generate_uintx_quantization_types(1), - **generate_uintx_quantization_types(2), - **generate_uintx_quantization_types(3), - **generate_uintx_quantization_types(4), - **generate_uintx_quantization_types(5), - **generate_uintx_quantization_types(6), - **generate_uintx_quantization_types(7), - # **generate_uintx_quantization_types(8), # uint8 quantization is not supported + "uint1wo": partial(uintx_weight_only, dtype=torch.uint1), + "uint2wo": partial(uintx_weight_only, dtype=torch.uint2), + "uint3wo": partial(uintx_weight_only, dtype=torch.uint3), + "uint4wo": partial(uintx_weight_only, dtype=torch.uint4), + "uint5wo": partial(uintx_weight_only, dtype=torch.uint5), + "uint6wo": partial(uintx_weight_only, dtype=torch.uint6), + "uint7wo": partial(uintx_weight_only, dtype=torch.uint7), + # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported } SHORTHAND_QUANTIZATION_TYPES = { From 5f75db23c5d6e7e2239be0e9433808b35f7d12a1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 7 Dec 2024 01:08:11 +0530 Subject: [PATCH 24/27] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Sayak Paul --- docs/source/en/quantization/overview.md | 2 +- docs/source/en/quantization/torchao.md | 27 +++++++++++++------------ src/diffusers/utils/import_utils.py | 2 +- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 99d381e3a537..151b22a607a4 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be ## When to use what? -This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes` and `torchao`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file +Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). 
Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use. \ No newline at end of file diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 6730b5971baf..4526455c5a59 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -11,17 +11,18 @@ specific language governing permissions and limitations under the License. --> # torchao -[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more. -Before you begin, make sure you have Pytorch version 2.5, or above, and TorchAO installed: +Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed. ```bash pip install -U torch torchao ``` -## Usage -Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. Loading pre-quantized models is supported as well! This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. +Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + +The example below only quantizes the weights to int8. ```python from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig @@ -48,16 +49,16 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] image.save("output.png") ``` -TorchAO offers seamless compatibility with `torch.compile`, setting it apart from other quantization methods. This ensures one to achieve remarkable speedups with ease. +TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code. ```python # In the above code, add the following after initializing the transformer transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) ``` -For speed/memory benchmarks on Flux/CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). +For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware. 
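For reference, the sketch below ties together the quantize-then-compile flow described in this documentation patch. It is a minimal illustration rather than part of the patch itself: the checkpoint id is a placeholder, `int4wo` is the shorthand for `int4_weight_only`, and `group_size` is one of the keyword arguments forwarded to torchao's `int4_weight_only`; swap in whichever model and quantization type you actually use.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# Placeholder checkpoint id; substitute the Flux checkpoint you are working with.
model_id = "black-forest-labs/FLUX.1-dev"

# int4 weight-only quantization; `group_size` is passed through to torchao's int4_weight_only.
quantization_config = TorchAoConfig("int4wo", group_size=128)

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,  # int4 weight-only is typically used with bfloat16
)

# Optional: compile the quantized transformer for additional speedups, as described above.
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```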
-Additionally, TorchAO supports an automatic quantization API exposed with [`autoquant`](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. This can directly be used with the underlying modeling components at the moment, but Diffusers will also expose an autoquant configuration option in the future. +torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future. The `TorchAoConfig` class accepts three parameters: - `quant_type`: A string value mentioning one of the quantization types below. @@ -66,11 +67,11 @@ The `TorchAoConfig` class accepts three parameters: ## Supported quantization types -Broadly, quantization in the follow data types is supported: `int8`, `float3-float8` and `uint1-uint7`. Among these types, there exists weight-only quantization techniques and weight + dynamic-activation quantization techniques. +torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7. -Weight-only quantization refers to storing the model weights in a specific low-bit data type but performing computation in a higher precision data type, like `bfloat16`. This lowers the memory requirements from model weights, but retains the memory peaks for activation computation. +Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation. -Dynamic Activation quantization refers to storing the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly before settling for your favourite quantization method. +Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly. The quantization methods supported are as follows: @@ -94,11 +95,11 @@ The quantization methods supported are as follows: - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, `uint_a16w5`, `uint_a16w6`, `uint_a16w7` -The "Documentation shorthands/Common speak" representation is simply the underlying storage dtype with the number of bits for storing activations and weights respectively. 
+The "Documentation shorthands/Common speak" refers to the underlying storage dtype with the number of bits for storing activations and weights, respectively. For example, int_a16w8 stores the activations in 16-bit and the weights in 8-bit. -Note that some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows the usage of the quantization methods as specified in the TorchAO docs as-is, while also making it convenient to use easy to remember shorthand notations. +Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations. -It is recommended to check out the official TorchAO Documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. +Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. ## Resources diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 274bda74a391..e48c245fa2dd 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -344,7 +344,7 @@ def is_timm_available(): if _is_torchao_available: try: _torchao_version = importlib_metadata.version("torchao") - logger.debug(f"Successfully import gguf version {_torchao_version}") + logger.debug(f"Successfully import torchao version {_torchao_version}") except importlib_metadata.PackageNotFoundError: _is_torchao_available = False From 7d9d1dc02f148c827827f8d7561cbe24d9f31e02 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 22:28:20 +0100 Subject: [PATCH 25/27] address review comments --- docs/source/en/quantization/torchao.md | 29 ++++---------- .../quantizers/quantization_config.py | 39 +------------------ tests/quantization/torchao/test_torchao.py | 7 +++- 3 files changed, 14 insertions(+), 61 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 4526455c5a59..bd5c7697a0f7 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -45,7 +45,7 @@ pipe = FluxPipeline.from_pretrained( pipe.to("cuda") prompt = "A cat holding a sign that says hello world" -image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] +image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0] image.save("output.png") ``` @@ -75,27 +75,12 @@ Dynamic activation quantization stores the model weights in a low-bit dtype, whi The quantization methods supported are as follows: -- **Integer quantization:** - - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` - - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` - - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8` - -- **Floating point 8-bit quantization:** - - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` - - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq` - - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, 
`float_a8w8`, `float_a16w8` - -- **Floating point X-bit quantization:** - - Full function names: `fpx_weight_only` - - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must be satisfied for a given shorthand notation. - - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`, `float_a16w6`, `float_a16w7`, `float_a16w8` - -- **Unsigned Integer quantization:** - - Full function names: `uintx_weight_only` - - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` - - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, `uint_a16w5`, `uint_a16w6`, `uint_a16w7` - -The "Documentation shorthands/Common speak" refers to the underlying storage dtype with the number of bits for storing activations and weights, respectively. For example, int_a16w8 stores the activations in 16-bit and the weights in 8-bit. +| **Category** | **Full Function Names** | **Shorthands** | +|--------------|-------------------------|----------------| +| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` | +| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` | +| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` | +| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` | Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations. diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index bc5f42f0b238..4aeb75ab704c 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -405,29 +405,22 @@ class TorchAoConfig(QuantizationConfigMixin): - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` - - Documentation shorthands/Common speak: `int_a16w4`, `int_a8w4`, `int_a16w8`, `int_a8w8` - **Floating point 8-bit quantization:** - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, - `float8_e4m3_tensor`, `float8_e4m3_row`, `float8sq` - - Documentation shorthands/Common speak: `float8_e5m2_a16w8`, `float8_e4m3_a16w8`, `float_a8w8`, - `float_a16w8` + `float8_e4m3_tensor`, `float8_e4m3_row`, - **Floating point X-bit quantization:** - Full function names: `fpx_weight_only` - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of exponent bits and `B` is the number of mantissa bits. 
The constraint of `X == A + B + 1` must be satisfied for a given shorthand notation. - - Documentation shorthands/Common speak: `float_a16w3`, `float_a16w4`, `float_a16w5`, - `float_a16w6`, `float_a16w7`, `float_a16w8` - **Unsigned Integer quantization:** - Full function names: `uintx_weight_only` - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` - - Documentation shorthands/Common speak: `uint_a16w1`, `uint_a16w2`, `uint_a16w3`, `uint_a16w4`, - `uint_a16w5`, `uint_a16w6`, `uint_a16w7` modules_to_not_convert (`List[str]`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. @@ -584,7 +577,6 @@ def generate_fpx_quantization_types(bits: int): **generate_float8dq_types(torch.float8_e4m3fn), # float8 weight + float8 activation (static) "float8_static_activation_float8_weight": float8_static_activation_float8_weight, - "float8sq": float8_static_activation_float8_weight, # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly # fpx weight + bfloat16/float16 activation **generate_fpx_quantization_types(3), @@ -606,42 +598,13 @@ def generate_fpx_quantization_types(bits: int): # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported } - SHORTHAND_QUANTIZATION_TYPES = { - "int_a16w4": int4_weight_only, - "int_a8w4": int8_dynamic_activation_int4_weight, - "int_a16w8": int8_weight_only, - "int_a8w8": int8_dynamic_activation_int8_weight, - "uint_a16w1": partial(uintx_weight_only, dtype=torch.uint1), - "uint_a16w2": partial(uintx_weight_only, dtype=torch.uint2), - "uint_a16w3": partial(uintx_weight_only, dtype=torch.uint3), - "uint_a16w4": partial(uintx_weight_only, dtype=torch.uint4), - "uint_a16w5": partial(uintx_weight_only, dtype=torch.uint5), - "uint_a16w6": partial(uintx_weight_only, dtype=torch.uint6), - "uint_a16w7": partial(uintx_weight_only, dtype=torch.uint7), - # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported - } - - SHORTHAND_FLOAT_QUANTIZATION_TYPES = { - "float_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - "float_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), - "float_a8w8": float8_dynamic_activation_float8_weight, - "float_a16w3": partial(fpx_weight_only, ebits=2, mbits=0), - "float_a16w4": partial(fpx_weight_only, ebits=2, mbits=1), - "float_a16w5": partial(fpx_weight_only, ebits=3, mbits=1), - "float_a16w6": partial(fpx_weight_only, ebits=3, mbits=2), - "float_a16w7": partial(fpx_weight_only, ebits=4, mbits=2), - "float_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - } - QUANTIZATION_TYPES = {} QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) - QUANTIZATION_TYPES.update(SHORTHAND_QUANTIZATION_TYPES) if cls._is_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) - QUANTIZATION_TYPES.update(SHORTHAND_FLOAT_QUANTIZATION_TYPES) return QUANTIZATION_TYPES else: diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index eb09119edf61..55bdb639c3b6 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -33,6 +33,7 @@ enable_full_determinism, is_torch_available, is_torchao_available, + 
nightly, require_torch, require_torch_gpu, require_torchao_version_greater, @@ -280,7 +281,8 @@ def test_int4wo_quant_bfloat16_conversion(self): def test_offload(self): """ - Test if the quantized model int4 weight-only is working properly with cpu/disk offload. + Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies + that the device map is correctly set (in the `hf_device_map` attribute of the model). """ device_map_offload = { @@ -306,6 +308,8 @@ def test_offload(self): offload_folder=offload_folder, ) + self.assertTrue(quantized_model.hf_device_map == device_map_offload) + output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() @@ -539,6 +543,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): @require_torch_gpu @require_torchao_version_greater("0.6.0") @slow +@nightly class SlowTorchAoTests(unittest.TestCase): def tearDown(self): gc.collect() From e9fccb6bbbfad5a86b99f866f7a39a1439b65423 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 23:02:59 +0100 Subject: [PATCH 26/27] update --- .../quantizers/torchao/torchao_quantizer.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index b98bdd9c18f9..8b28a403e6f0 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -139,14 +139,34 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + quant_type = self.quantization_config.quant_type + + if quant_type.startswith("int8") or quant_type.startswith("int4"): + # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 + return torch.int8 + elif quant_type == "uintx_weight_only": + return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) + elif quant_type.startswith("uint"): + return { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + }[int(quant_type[4])] + elif quant_type.startswith("float") or quant_type.startswith("fp"): + return torch.bfloat16 + if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): return target_dtype # We need one of the supported dtypes to be selected in order for accelerate to determine - # the total size of modules/parameters for auto device placement. This method will not be - # called when device_map is not "auto". + # the total size of modules/parameters for auto device placement. + possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"] raise ValueError( - f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype " + f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype " f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the " f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." 
) From bc874fcf8206d9605cc9bf4aaf5ed66816a445ab Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 23:19:50 +0100 Subject: [PATCH 27/27] add nightly marker for torch.compile test --- tests/quantization/torchao/test_torchao.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 55bdb639c3b6..5c71fc4e0ae7 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -365,6 +365,7 @@ def test_training(self): self.assertTrue(module.adapter[1].weight.grad is not None) self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + @nightly def test_torch_compile(self): r"""Test that verifies if torch.compile works with torchao quantization.""" quantization_config = TorchAoConfig("int8_weight_only")
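
To make the branching that the `adjust_target_dtype` rework above introduces easier to follow, here is a simplified, standalone sketch of the same dtype-selection idea. The helper name and the final print lines are illustrative only and not part of the diffusers API; the point is how a `quant_type` string is resolved to a storage dtype that accelerate can use to estimate parameter sizes when a device map such as `"auto"` or `"balanced"` is requested.

```python
import torch

# Mapping of uintX bit-widths to the corresponding torch dtypes (uint8 quantization is not supported).
_UINT_DTYPES = {
    1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4,
    5: torch.uint5, 6: torch.uint6, 7: torch.uint7,
}


def target_dtype_for(quant_type: str, quant_type_kwargs: dict) -> torch.dtype:
    """Illustrative helper mirroring the quantizer's dtype-selection logic."""
    if quant_type.startswith(("int8", "int4")):
        # int4 weights are packed into torch.int8 storage; there is no torch.int4.
        return torch.int8
    if quant_type == "uintx_weight_only":
        # The storage dtype may be passed explicitly, e.g. TorchAoConfig("uintx_weight_only", dtype=torch.uint4).
        return quant_type_kwargs.get("dtype", torch.uint8)
    if quant_type.startswith("uint"):
        # Shorthands like "uint4wo" encode the bit-width at index 4.
        return _UINT_DTYPES[int(quant_type[4])]
    if quant_type.startswith(("float", "fp")):
        return torch.bfloat16
    raise ValueError(f"Cannot infer a target dtype for quant_type={quant_type!r}")


print(target_dtype_for("int8wo", {}))    # torch.int8
print(target_dtype_for("uint4wo", {}))   # torch.uint4
print(target_dtype_for("float8wo", {}))  # torch.bfloat16
```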