Skip to content

Commit ceb09f8

Browse files
committed
[Quantization] support pass MappingType for TorchAoConfig
1 parent 54043c3 commit ceb09f8

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/diffusers/quantizers/quantization_config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class QuantizationMethod(str, Enum):
4646
GGUF = "gguf"
4747
TORCHAO = "torchao"
4848

49+
# BUG FIX: the original read `if is_torchao_available:` — that tests the
# truthiness of the function object itself (always True), so the torchao
# import below would run even when torchao is not installed and crash the
# module import with an ImportError. The availability check must be CALLED.
if is_torchao_available():
    from torchao.quantization.quant_primitives import MappingType

    class TorchAoJSONEncoder(json.JSONEncoder):
        """JSON encoder that can serialize torchao ``MappingType`` enum members.

        ``json.dumps`` raises ``TypeError`` for enum values; ``__repr__`` of the
        quantization config passes this encoder via ``cls=`` so kwargs such as
        ``act_mapping_type=MappingType.SYMMETRIC`` render as the plain string
        ``"SYMMETRIC"`` instead of failing.
        """

        def default(self, obj):
            # Represent the enum member by its name; everything else falls
            # back to the base encoder (which raises TypeError as usual).
            if isinstance(obj, MappingType):
                return obj.name
            return super().default(obj)
4956

5057
@dataclass
5158
class QuantizationConfigMixin:
@@ -673,4 +680,4 @@ def __repr__(self):
673680
```
674681
"""
675682
config_dict = self.to_dict()
676-
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
683+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"

tests/quantization/torchao/test_torchao.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def forward(self, input, *args, **kwargs):
7777
from torchao.dtypes import AffineQuantizedTensor
7878
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
7979
from torchao.utils import get_model_size_in_bytes
80+
from torchao.quantization.quant_primitives import MappingType
8081

8182

8283
@require_torch
@@ -122,6 +123,19 @@ def test_repr(self):
122123
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
123124
self.assertEqual(quantization_repr, expected_repr)
124125

126+
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
127+
expected_repr = """TorchAoConfig {
128+
"modules_to_not_convert": null,
129+
"quant_method": "torchao",
130+
"quant_type": "int4dq",
131+
"quant_type_kwargs": {
132+
"act_mapping_type": "SYMMETRIC",
133+
"group_size": 64
134+
}
135+
}""".replace(" ", "").replace("\n", "")
136+
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
137+
self.assertEqual(quantization_repr, expected_repr)
138+
125139

126140
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
127141
@require_torch

0 commit comments

Comments
 (0)