Commit ebfe173

[Feat] Allow symmetric_no_clipping_error for KleidiAI kernels, update Readme and validate Kleidi INT4 quantization path (#2570)
* [Feat] Restore and validate KleidiAI INT4 quantization path using updated quantizer API
  - Switched to quantize_() with Int8DynamicActivationIntxWeightConfig
  - Validated the move of packed_linear_int8_dynamic_activation_intx_weight_layout.py into torchao/dtypes/uintx
  - Fixed handling of SYMMETRIC_NO_CLIPPING_ERR mapping type
  - Validated INT4 path on a 2-layer nn.Sequential model with torch.int4 weights
  - Compared SYMMETRIC vs SYMMETRIC_NO_CLIPPING_ERR across PerAxis and PerGroup granularities
* [Fix]: Allow "SYMMETRIC_NO_CLIPPING_ERR" in Int8DynamicActivationIntxWeightConfig
* [FEAT]: Add SYMMETRIC_NO_CLIPPING_ERR to tests
* Update test_int8_dynamic_activation_intx_weight.py

---------

Co-authored-by: Scott Roy <[email protected]>
1 parent 840b7ce commit ebfe173
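
For context, the validation described in the commit message roughly corresponds to a script like the sketch below. The toy nn.Sequential, layer sizes, input shape, and comparison loop are illustrative assumptions, not the exact harness used for the PR:

```python
# Illustrative sketch of the validation flow described in the commit message;
# the toy model, shapes, and comparison loop are assumptions, not the PR's script.
import copy

import torch
from torch import nn

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType

# 2-layer toy model whose weights will be quantized to torch.int4
model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 64))
x = torch.randn(8, 512)
ref = model(x)

for mapping_type in [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]:
    for granularity in [PerAxis(0), PerGroup(32)]:
        m = copy.deepcopy(model)
        quantize_(
            m,
            Int8DynamicActivationIntxWeightConfig(
                weight_dtype=torch.int4,
                weight_granularity=granularity,
                weight_mapping_type=mapping_type,
            ),
        )
        err = (m(x) - ref).abs().max().item()
        print(f"{mapping_type}, {granularity}: max abs error {err:.5f}")
```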

5 files changed: +54 -35 lines

LICENSE

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,6 @@
 Copyright 2023 Meta
+All contributions by Arm:
+Copyright (c) 2024-2025 Arm Limited and/or its affiliates
 
 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
 
torchao/experimental/docs/readme.md

Lines changed: 0 additions & 32 deletions

@@ -96,38 +96,6 @@ quantize_(
         layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() is also supported, but much slower on CPU
     ),
 )
-```
-
-KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows:
-
-```python
-from torchao.dtypes import PlainLayout
-from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
-    PackedLinearInt8DynamicActivationIntxWeightLayout,
-)
-from torchao.experimental.quant_api import (
-    int8_dynamic_activation_intx_weight,
-)
-from torchao.quantization.granularity import (
-    PerGroup,
-    PerRow,
-)
-from torchao.quantization.quant_api import quantize_
-from torchao.quantization.quant_primitives import MappingType
-
-my_model = Model()
-
-quantize_(
-    my_model,
-    int8_dynamic_activation_intx_weight(
-        weight_dtype=torch.int4,
-        granularity=PerGroup(32), # PerRow() is also supported
-        has_weight_zeros=True, # Should be True
-        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR # MappingType.SYMMETRIC can also be used but increases error
-        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
-    ),
-)
-```
 
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 8 additions & 0 deletions

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the

@@ -54,6 +55,7 @@ class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
             for weight_mapping_type in [
                 MappingType.SYMMETRIC,
                 MappingType.ASYMMETRIC,
+                MappingType.SYMMETRIC_NO_CLIPPING_ERR,
             ]
             for weight_granularity in [
                 PerGroup(128),

@@ -71,6 +73,12 @@ def test_accuracy(
         """
        Checks the accuracy of packed layouts
        """
+        if (
+            weight_dtype == torch.int1
+            and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR
+        ):
+            return
+
         m = 3
         n = 1071
         k = 2048
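
The new early return skips the one combination where the no-clipping mapping is not well defined: for 1-bit weights the signed quantized range is presumably [-1, 0], so a scale derived from max_val / quant_max would divide by zero. A rough sketch of that reasoning follows; the bounds and formulas are assumptions paraphrasing torchao's symmetric mappings, not code from this commit:

```python
# Rough sketch of why SYMMETRIC_NO_CLIPPING_ERR is skipped for torch.int1.
# The bounds and scale formulas are assumptions, not code from this commit.
quant_min, quant_max = -1, 0      # assumed signed 1-bit quantization bounds
min_val, max_val = -0.7, 0.9      # example weight range for one group

# SYMMETRIC: a single scale derived from the largest magnitude
scale_symmetric = max(abs(min_val), abs(max_val)) / ((quant_max - quant_min) / 2)

# SYMMETRIC_NO_CLIPPING_ERR: separate scales for the negative and positive
# halves so that neither extreme clips -- but max_val / quant_max divides by
# zero when quant_max == 0, so the test skips this combination.
# scale_no_clip = max(min_val / quant_min, max_val / quant_max)  # ZeroDivisionError
```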

torchao/quantization/README.md

Lines changed: 34 additions & 0 deletions

@@ -205,6 +205,40 @@ quantize_(model, FPXWeightOnlyConfig(3, 2))
 
 You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype.
 
+```
+
+KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    quantize_,
+)
+from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+    Target,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+from torchao.quantization.quant_primitives import MappingType
+from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
+
+my_model = Model()
+
+# Set quantization layout
+layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN)
+
+quantize_(
+    my_model,
+    Int8DynamicActivationIntxWeightConfig(
+        weight_scale_dtype=torch.float32,
+        weight_granularity=PerGroup(32), # PerAxis is also supported
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC can also be used but increases error
+        layout=layout,
+        weight_dtype=torch.int4,
+    ),
+)
+```
+
 ## Affine Quantization Details
 Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.
 
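
The affine formula quoted in the context lines above is easy to sanity-check with a small numeric example. The tensor values, scale, and int4 clamp range below are arbitrary illustrative choices, not taken from the repository:

```python
# Tiny worked example of the affine mapping:
#   quantized_val = high_precision_float_val / scale + zero_point
import torch

x = torch.tensor([-1.20, 0.00, 0.35, 2.40])
scale, zero_point = 0.3, 0          # symmetric case: zero_point stays 0

q = torch.clamp(torch.round(x / scale) + zero_point, -8, 7)  # int4 range
x_hat = (q - zero_point) * scale    # dequantize to inspect the error

print(q)      # tensor([-4.,  0.,  1.,  7.])  -- 2.40/0.3 = 8 clips to 7
print(x_hat)  # tensor([-1.2000,  0.0000,  0.3000,  2.1000])
```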

torchao/quantization/quant_api.py

Lines changed: 10 additions & 3 deletions

@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and affiliates.
 # All rights reserved.
-
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 

@@ -862,8 +862,9 @@ def __post_init__(self):
         assert self.weight_mapping_type in [
             MappingType.ASYMMETRIC,
             MappingType.SYMMETRIC,
+            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
         ], (
-            f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.weight_mapping_type}"
+            f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}"
         )
         assert self.act_mapping_type in [
             MappingType.ASYMMETRIC,

@@ -917,6 +918,12 @@ def _int8_dynamic_activation_intx_weight_transform(
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
 
     # We quantize with QDQLayout, and then construct the packed weight tensor later
+    # set preserve_zero based on weight mapping type
+    preserve_zero = weight_mapping_type in [
+        MappingType.SYMMETRIC,
+        MappingType.SYMMETRIC_NO_CLIPPING_ERR,
+    ]
+
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=weight_mapping_type,

@@ -926,7 +933,7 @@
         quant_max=quant_max,
         scale_dtype=weight_scale_dtype,
         zero_point_dtype=torch.int8,
-        preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC),
+        preserve_zero=preserve_zero,
         zero_point_domain=ZeroPointDomain.INT,
         _layout=QDQLayout(),
     )
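
The transform now sets preserve_zero for both symmetric mappings, since neither uses a zero point. For readers unfamiliar with SYMMETRIC_NO_CLIPPING_ERR, the sketch below contrasts how the two mappings pick a scale; the formulas and example values paraphrase torchao's symmetric quantization and are assumptions, not code from this commit:

```python
# Minimal sketch contrasting the two symmetric mappings for int4 weights.
# The formulas paraphrase torchao's qparam selection and may differ in detail.
import torch

quant_min, quant_max = -8, 7                      # int4 bounds
w = torch.tensor([-0.5, 0.3, 1.1, 2.0])           # example weight group

# SYMMETRIC: one scale from the largest magnitude; here 2.0 / scale = 7.5,
# which rounds to 8 and is clamped back to 7 (a clipping error).
scale_sym = max(w.min().abs(), w.max().abs()) / ((quant_max - quant_min) / 2)

# SYMMETRIC_NO_CLIPPING_ERR: pick a scale large enough that neither the
# negative nor the positive extreme rounds outside [quant_min, quant_max].
scale_no_clip = max(w.min() / quant_min, w.max() / quant_max)

for name, scale in [("SYMMETRIC", scale_sym), ("SYMMETRIC_NO_CLIPPING_ERR", scale_no_clip)]:
    q = torch.clamp(torch.round(w / scale), quant_min, quant_max)
    max_err = (w - q * scale).abs().max().item()
    print(f"{name}: scale={scale.item():.4f}, max reconstruction error={max_err:.4f}")
```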
