Commit f9bc52d

Updates LUT tensor and new convert API (#2984)
* up * up * up * up * up
1 parent c4d4799 commit f9bc52d
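At a glance, the change swaps the per-tensor conversion config (StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig plus quantize_) for a single model-level helper, _convert_model_for_aarch64, and replaces the arm64-mac test gate with a kernel-library check. A minimal sketch of the new flow, pieced together from the test changes below; the toy model is illustrative only, and the input model is assumed to already be quantized with quantize_(model, StretchedIntxWeightConfig(...)) as in the tests:

# Sketch only: import paths are taken from this commit's test diff; the model
# here stands in for a real PARQ-quantized model.
from copy import deepcopy

import torch

from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import (
    _is_kernel_library_loaded,
)
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64

assert _is_kernel_library_loaded(), "aarch64 LUT kernels are not available"

# Assume `model` was already quantized, e.g. with
# quantize_(model, StretchedIntxWeightConfig(..., activation_quantization="int8_asym_per_token")).
model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.Linear(256, 8))

lut_model = deepcopy(model)
_convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor")  # new convert API

x = torch.randn(5, 128)
torch.testing.assert_close(model(x), lut_model(x), atol=1e-1, rtol=1e-1)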

File tree

10 files changed: +317 additions, -361 deletions

.github/workflows/regression_test_aarch64.yml

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ jobs:
           pytest -s test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py
           pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py
           pytest -s test/prototype/test_embedding.py
-          pytest -s test/prototype/test_dynamic_activation_lut.py
+          pytest -s test/prototype/test_int8_lut_tensor.py
           pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py
           pytest -s test/prototype/test_parq.py
       - name: torchao/csrc/cpu - build and run C++ tests

test/prototype/test_dynamic_activation_lut.py renamed to test/prototype/test_int8_lut_tensor.py

Lines changed: 15 additions & 38 deletions
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-import platform
-import sys
 from copy import deepcopy
 
 import pytest
@@ -15,16 +13,14 @@
     StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
 )
-from torchao.prototype.quantization.dynamic_activation_lut import (
-    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
+from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import (
+    _is_kernel_library_loaded,
 )
+from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
 from torchao.quantization import quantize_
 from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.quant_api import _is_linear
 from torchao.quantization.utils import compute_error
 
-is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64"
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, d1=512, d2=256, d3=128, d4=8):
@@ -59,7 +55,9 @@ def run_before_and_after_tests():
 @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
 @pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
 @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
-@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
+@pytest.mark.skipif(
+    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
+)
 def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
     torch.manual_seed(0)
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
@@ -68,38 +66,22 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
-        activation_quantization=None,
-        version=1,
+        activation_quantization="int8_asym_per_token",
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype)
     activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype)
-    parq_model_with_dyn_quant = deepcopy(parq_model)
     quantize_(parq_model, config)
 
-    # Apply dynamic activation to parq model. This will serve as the LUT reference
-    dyn_act_config = deepcopy(config)
-    dyn_act_config.activation_quantization = "int8_asym_per_token"
-    quantize_(parq_model_with_dyn_quant, dyn_act_config, filter_fn=_is_linear)
-
     # Convert PARQ model to lowbit LUT model
     lut_model = deepcopy(parq_model)
-    conversion_config = (
-        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
-            config.b, config.granularity
-        )
-    )
-    quantize_(lut_model, conversion_config, filter_fn=conversion_config.get_filter_fn())
+    _convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor")
 
     # Run both models and compare
     parq_out = parq_model(activations)
-    parq_with_dyn_quant_out = parq_model_with_dyn_quant(activations)
     lut_out = lut_model(activations)
 
-    sqnr = compute_error(parq_out, parq_with_dyn_quant_out).item()
-    assert sqnr > 20.0, f"sqnr {sqnr} is too low"
-
-    sqnr = compute_error(lut_out, parq_with_dyn_quant_out).item()
+    sqnr = compute_error(parq_out, lut_out).item()
     if dtype == torch.float32:
         assert sqnr > 40.0, f"sqnr {sqnr} is too low"
     elif dtype == torch.bfloat16:
@@ -112,32 +94,27 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
 @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
 @pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
 @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
-@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
+@pytest.mark.skipif(
+    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
+)
 def test_export(dtype, granularity, bit_width, lead_dim):
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
     config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
-        activation_quantization=None,
-        version=1,
+        activation_quantization="int8_asym_per_token",
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype)
     activations = parq_model.example_inputs(lead_dim=lead_dim)
     quantize_(parq_model, config)
 
-    conversion_config = (
-        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
-            config.b, config.granularity
-        )
-    )
-    quantize_(
-        parq_model, conversion_config, filter_fn=conversion_config.get_filter_fn()
-    )
+    _convert_model_for_aarch64(parq_model)
 
     ep = torch.export.export(parq_model, (activations,))
+
     assert (
         f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default"
        in ep.graph_module.code
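The single remaining accuracy check compares the PARQ reference directly against the converted LUT model with compute_error, which reports SQNR in dB. To read the thresholds (e.g. sqnr > 40.0 for float32), the usual definition is sketched below; this is a plain restatement of the metric, not necessarily torchao's exact implementation:

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Signal-to-quantized-noise ratio in decibels:
    # 20 * log10(||reference|| / ||reference - candidate||).
    # 40 dB means the error norm is roughly 1% of the signal norm.
    signal = torch.linalg.vector_norm(reference.float())
    noise = torch.linalg.vector_norm((reference - candidate).float())
    return (20 * torch.log10(signal / noise)).item()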

torchao/prototype/parq/quant/config_torchao.py

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,4 @@
+import types
 from dataclasses import dataclass
 from typing import Callable, Optional
 
@@ -17,6 +18,7 @@
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
     _int8_asymm_per_token_quant,
+    _linear_extra_repr,
 )
 from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
 from torchao.quantization.transform_module import register_quantize_module_handler
@@ -117,7 +119,12 @@ def _int8_dynamic_activation_stretched_intx_transform(
         weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant)
     elif config.activation_quantization is not None:
         raise ValueError(f"Unsupported {config.activation_quantization=}")
+
     module.weight = nn.Parameter(weight, requires_grad=False)
+
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+
     return module
 
 
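The last hunk rebinds extra_repr on the transformed module so the quantized linear prints a more informative repr. The mechanism is ordinary instance-method binding with types.MethodType; a self-contained illustration follows (the repr text here is made up for the example, not torchao's _linear_extra_repr):

import types

import torch.nn as nn


def _verbose_extra_repr(self) -> str:
    # Example repr: report the weight's runtime type alongside the usual shape info.
    return (
        f"in_features={self.in_features}, out_features={self.out_features}, "
        f"weight={type(self.weight).__name__}"
    )


linear = nn.Linear(4, 2)
linear.extra_repr = types.MethodType(_verbose_extra_repr, linear)
print(linear)  # Linear(in_features=4, out_features=2, weight=Parameter)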

torchao/prototype/quantization/dynamic_activation_lut/__init__.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

torchao/prototype/quantization/dynamic_activation_lut/api.py

Lines changed: 0 additions & 83 deletions
This file was deleted.
