diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 25cd4ad448e7..5770e32c909e 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -23,7 +23,7 @@
 
 from packaging import version
 
-from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
+from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
 from ..base import DiffusersQuantizer
 
 
@@ -35,21 +35,28 @@
     import torch
     import torch.nn as nn
 
-    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
-        # At the moment, only int8 is supported for integer quantization dtypes.
-        # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
-        # to support more quantization methods, such as intx_weight_only.
-        torch.int8,
-        torch.float8_e4m3fn,
-        torch.float8_e5m2,
-        torch.uint1,
-        torch.uint2,
-        torch.uint3,
-        torch.uint4,
-        torch.uint5,
-        torch.uint6,
-        torch.uint7,
-    )
+    if is_torch_version(">=", "2.5"):
+        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+            # At the moment, only int8 is supported for integer quantization dtypes.
+            # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+            # to support more quantization methods, such as intx_weight_only.
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.uint1,
+            torch.uint2,
+            torch.uint3,
+            torch.uint4,
+            torch.uint5,
+            torch.uint6,
+            torch.uint7,
+        )
+    else:
+        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        )
 
 if is_torchao_available():
     from torchao.quantization import quantize_
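
For context, below is a minimal standalone sketch of the version gate this diff introduces: on torch >= 2.5 the full dtype tuple (including the uint1-uint7 sub-byte dtypes) is exposed, while older installs fall back to the original three dtypes. The helper torch_version_at_least is a hypothetical stand-in for diffusers' internal is_torch_version utility, written with packaging.version so the snippet runs outside the library; the membership check at the end only illustrates how such a tuple might be consumed.

import torch
from packaging import version


def torch_version_at_least(ref: str) -> bool:
    # Simplified stand-in for diffusers.utils.is_torch_version(">=", ref):
    # compares the installed torch version (dev/local suffixes stripped via
    # base_version) against the reference string.
    installed = version.parse(version.parse(torch.__version__).base_version)
    return installed >= version.parse(ref)


if torch_version_at_least("2.5"):
    # Newer torch: include the sub-byte unsigned integer dtypes alongside
    # int8 and the two float8 variants.
    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
        torch.int8,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint1,
        torch.uint2,
        torch.uint3,
        torch.uint4,
        torch.uint5,
        torch.uint6,
        torch.uint7,
    )
else:
    # Older torch: torch.uint1-uint7 are not usable here, so only the
    # original three dtypes are advertised as supported.
    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
        torch.int8,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
    )

# Example check, roughly what a quantizer might do before quantizing a module:
print(torch.int8 in SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION)  # True on any branch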