From 5c9be75aeb466f81221b3adf51b85765e3c22c3c Mon Sep 17 00:00:00 2001 From: baymax591 Date: Tue, 24 Dec 2024 11:39:45 +0800 Subject: [PATCH 1/2] fix bug for torch.uint1-7 not support in torch<2.6 --- .../quantizers/torchao/torchao_quantizer.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 25cd4ad448e7..6eb4816dcad8 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -33,23 +33,29 @@ if is_torch_available(): import torch - import torch.nn as nn - - SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( - # At the moment, only int8 is supported for integer quantization dtypes. - # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future - # to support more quantization methods, such as intx_weight_only. - torch.int8, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - ) + + if version.parse(torch.__version__) >= version.parse('2.6'): + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) + else: + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + ) if is_torchao_available(): from torchao.quantization import quantize_ From 97827d747485c63d3af5432406de1c2852fc6aaf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 24 Dec 2024 06:29:05 +0100 Subject: [PATCH 2/2] up --- src/diffusers/quantizers/torchao/torchao_quantizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 6eb4816dcad8..5770e32c909e 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -23,7 +23,7 @@ from packaging import version -from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging +from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging from ..base import DiffusersQuantizer @@ -33,8 +33,9 @@ if is_torch_available(): import torch + import torch.nn as nn - if version.parse(torch.__version__) >= version.parse('2.6'): + if is_torch_version(">=", "2.5"): SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( # At the moment, only int8 is supported for integer quantization dtypes. # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future