Skip to content

Commit 6821dbe

Browse files
committed
[Quantization] support pass MappingType for TorchAoConfig
1 parent 37a5f1b commit 6821dbe

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/diffusers/quantizers/quantization_config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from packaging import version
3434

3535
from ..utils import is_torch_available, is_torchao_available, logging
36+
from torchao.quantization.quant_primitives import MappingType
3637

3738

3839
if is_torch_available():
@@ -46,6 +47,11 @@ class QuantizationMethod(str, Enum):
4647
GGUF = "gguf"
4748
TORCHAO = "torchao"
4849

50+
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that can serialize torchao ``MappingType`` enum members.

    ``json.dumps`` raises ``TypeError`` for enum values; this encoder emits a
    ``MappingType`` member as its ``name`` string (e.g. ``"SYMMETRIC"``) and
    defers every other unsupported type to the base class, which raises as usual.
    """

    def default(self, obj):
        # Anything that is not a MappingType keeps the stock behavior
        # (base-class default() raises TypeError for unserializable objects).
        if not isinstance(obj, MappingType):
            return super().default(obj)
        return obj.name
4955

5056
@dataclass
5157
class QuantizationConfigMixin:
@@ -673,4 +679,4 @@ def __repr__(self):
673679
```
674680
"""
675681
config_dict = self.to_dict()
676-
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
682+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomJSONEncoder)}\n"

tests/quantization/torchao/test_torchao.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def forward(self, input, *args, **kwargs):
7777
from torchao.dtypes import AffineQuantizedTensor
7878
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
7979
from torchao.utils import get_model_size_in_bytes
80+
from torchao.quantization.quant_primitives import MappingType
8081

8182

8283
@require_torch
@@ -122,6 +123,19 @@ def test_repr(self):
122123
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
123124
self.assertEqual(quantization_repr, expected_repr)
124125

126+
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
127+
expected_repr = """TorchAoConfig {
128+
"modules_to_not_convert": null,
129+
"quant_method": "torchao",
130+
"quant_type": "int4dq",
131+
"quant_type_kwargs": {
132+
"act_mapping_type": "SYMMETRIC",
133+
"group_size": 64
134+
}
135+
}""".replace(" ", "").replace("\n", "")
136+
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
137+
self.assertEqual(quantization_repr, expected_repr)
138+
125139

126140
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
127141
@require_torch

0 commit comments

Comments
 (0)