@@ -77,6 +77,7 @@ def forward(self, input, *args, **kwargs):
7777 from torchao.dtypes import AffineQuantizedTensor
7878 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
7979 from torchao.utils import get_model_size_in_bytes
80+ from torchao.quantization.quant_primitives import MappingType
8081
8182
8283@require_torch
@@ -122,6 +123,19 @@ def test_repr(self):
122123 quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
123124 self.assertEqual(quantization_repr, expected_repr)
124125
126+ quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
127+ expected_repr = """TorchAoConfig {
128+ "modules_to_not_convert": null,
129+ "quant_method": "torchao",
130+ "quant_type": "int4dq",
131+ "quant_type_kwargs": {
132+ "act_mapping_type": "SYMMETRIC",
133+ "group_size": 64
134+ }
135+ }""".replace(" ", "").replace("\n", "")
136+ quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
137+ self.assertEqual(quantization_repr, expected_repr)
138+
125139
126140# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
127141@require_torch
0 commit comments