
Commit 11d8e3c

[Quantization] support pass MappingType for TorchAoConfig (huggingface#10927)
* [Quantization] support pass MappingType for TorchAoConfig
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 97fda1b commit 11d8e3c
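
For context, here is a minimal usage sketch of what this change enables (an illustration, not part of the commit; it assumes diffusers with this commit and torchao are installed, and the values mirror the new test below):

# Minimal sketch: pass a torchao MappingType through TorchAoConfig's
# quant_type kwargs and print the config.
from diffusers import TorchAoConfig
from torchao.quantization.quant_primitives import MappingType

config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
# repr() now JSON-encodes the enum member by name, e.g. "act_mapping_type": "SYMMETRIC",
# instead of raising a TypeError from json.dumps.
print(config)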

2 files changed (+27, -1 lines)

src/diffusers/quantizers/quantization_config.py (13 additions, 1 deletion)

@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum):
     TORCHAO = "torchao"
 
 
+if is_torchao_available():
+    from torchao.quantization.quant_primitives import MappingType
+
+    class TorchAoJSONEncoder(json.JSONEncoder):
+        def default(self, obj):
+            if isinstance(obj, MappingType):
+                return obj.name
+            return super().default(obj)
+
+
 @dataclass
 class QuantizationConfigMixin:
     """

@@ -673,4 +683,6 @@ def __repr__(self):
         ```
         """
         config_dict = self.to_dict()
-        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+        return (
+            f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
+        )
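
For reference, the encoder pattern above in isolation, using a stand-in enum rather than the real torchao MappingType (a sketch, not part of the commit):

# Stand-alone illustration: json.dumps cannot serialize Enum members by default,
# so a custom JSONEncoder maps them to their .name, which is what
# TorchAoJSONEncoder does for torchao's MappingType.
import json
from enum import Enum, auto


class FakeMappingType(Enum):  # stand-in for torchao.quantization.quant_primitives.MappingType
    SYMMETRIC = auto()
    ASYMMETRIC = auto()


class EnumNameEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, FakeMappingType):
            return obj.name
        return super().default(obj)


print(json.dumps({"act_mapping_type": FakeMappingType.SYMMETRIC}, cls=EnumNameEncoder, indent=2, sort_keys=True))
# {
#   "act_mapping_type": "SYMMETRIC"
# }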

tests/quantization/torchao/test_torchao.py (14 additions, 0 deletions)

@@ -76,6 +76,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.quantization.quant_primitives import MappingType
     from torchao.utils import get_model_size_in_bytes
 
 

@@ -122,6 +123,19 @@ def test_repr(self):
         quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
         self.assertEqual(quantization_repr, expected_repr)
 
+        quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
+        expected_repr = """TorchAoConfig {
+            "modules_to_not_convert": null,
+            "quant_method": "torchao",
+            "quant_type": "int4dq",
+            "quant_type_kwargs": {
+                "act_mapping_type": "SYMMETRIC",
+                "group_size": 64
+            }
+        }""".replace(" ", "").replace("\n", "")
+        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+        self.assertEqual(quantization_repr, expected_repr)
+
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
