Skip to content

Commit 5c9be75

Browse files
committed
fix bug: torch.uint1–uint7 are not supported in torch < 2.6
1 parent c1e7fd5 commit 5c9be75

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,29 @@
3333

3434
if is_torch_available():
3535
import torch
36-
import torch.nn as nn
37-
38-
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
39-
# At the moment, only int8 is supported for integer quantization dtypes.
40-
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
41-
# to support more quantization methods, such as intx_weight_only.
42-
torch.int8,
43-
torch.float8_e4m3fn,
44-
torch.float8_e5m2,
45-
torch.uint1,
46-
torch.uint2,
47-
torch.uint3,
48-
torch.uint4,
49-
torch.uint5,
50-
torch.uint6,
51-
torch.uint7,
52-
)
36+
37+
if version.parse(torch.__version__) >= version.parse('2.6'):
38+
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
39+
# At the moment, only int8 is supported for integer quantization dtypes.
40+
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
41+
# to support more quantization methods, such as intx_weight_only.
42+
torch.int8,
43+
torch.float8_e4m3fn,
44+
torch.float8_e5m2,
45+
torch.uint1,
46+
torch.uint2,
47+
torch.uint3,
48+
torch.uint4,
49+
torch.uint5,
50+
torch.uint6,
51+
torch.uint7,
52+
)
53+
else:
54+
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
55+
torch.int8,
56+
torch.float8_e4m3fn,
57+
torch.float8_e5m2,
58+
)
5359

5460
if is_torchao_available():
5561
from torchao.quantization import quantize_

0 commit comments

Comments
 (0)