diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index a6e4dd9ff5e5..440ef2bf6230 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -47,6 +47,18 @@ class QuantizationMethod(str, Enum):
     TORCHAO = "torchao"
 
 
+if is_torchao_available():
+    from torchao.quantization.quant_primitives import MappingType
+
+    class TorchAoJSONEncoder(json.JSONEncoder):
+        """JSON encoder that serializes torchao `MappingType` members by their name."""
+
+        def default(self, obj):
+            if isinstance(obj, MappingType):
+                return obj.name
+            return super().default(obj)
+
+
 @dataclass
 class QuantizationConfigMixin:
     """
@@ -673,4 +685,7 @@ def __repr__(self):
         ```
         """
         config_dict = self.to_dict()
-        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+        # `quant_type_kwargs` may hold torchao `MappingType` values, which the default JSON encoder rejects.
+        return (
+            f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
+        )
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index adcd605e5806..e14a1cc0369e 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -76,6 +76,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.quantization.quant_primitives import MappingType
     from torchao.utils import get_model_size_in_bytes
 
 
@@ -122,6 +123,19 @@ def test_repr(self):
         quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
         self.assertEqual(quantization_repr, expected_repr)
 
+        quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
+        expected_repr = """TorchAoConfig {
+            "modules_to_not_convert": null,
+            "quant_method": "torchao",
+            "quant_type": "int4dq",
+            "quant_type_kwargs": {
+                "act_mapping_type": "SYMMETRIC",
+                "group_size": 64
+            }
+        }""".replace(" ", "").replace("\n", "")
+        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+        self.assertEqual(quantization_repr, expected_repr)
+
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch