Commit 3a5819e

Add exportable coreml codebook quantization op (#2443)
Summary: Added CoreML codebook quant (Palettization): https://apple.github.io/coremltools/docs-guides/source/opt-palettization-overview.html#palettization-overview

* supports group_size `per_grouped_channel`
* doesn't support vector quantization yet, but it will be easy to turn on if needed
* ops added: choose_qparams_and_quantize_codebook, dequantize_codebook
* also enabled support for export; these two ops will be preserved after export
* added CodebookWeightOnlyConfig(dtype, block_size) that can be used with quantize_ to quantize the Tensor

Test Plan: python test/prototype/test_coreml_codebook.py
1 parent 994a4ba commit 3a5819e
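
A minimal end-to-end sketch of the new API, distilled from the tests added in this commit (the layer sizes and the group size of 4 are illustrative):

import torch
from torchao.quantization import quantize_
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig

# a toy model; any module with Linear layers works the same way
m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)

# palettize weights: uint8 codes, one lookup table per group of 4 input channels
quantize_(m, CodebookWeightOnlyConfig(dtype=torch.uint8, block_size=[-1, 4]))

# export preserves the dequantize_codebook op in the graph
ep = torch.export.export(m, (torch.randn(1, 128, dtype=torch.float32),))
assert torch.ops.quant.dequantize_codebook.default in (
    n.target for n in ep.module().graph.nodes
)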

File tree

6 files changed: +525 -6 lines changed
test/prototype/test_coreml_codebook.py (91 additions, 0 deletions)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

from torchao.prototype.quantization.codebook_coreml import (
    CodebookQuantizedTensor,
    CodebookWeightOnlyConfig,
    choose_qparams_and_quantize_codebook_coreml,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_no_cuda
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least


@unittest.skipIf(
    not is_package_at_least("coremltools", "8.3.0"), "Requires coremltools >= 8.3.0"
)
class TestCodebookQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.input = torch.randn(100, 256, dtype=torch.float32)
        self.code_dtype = torch.uint8
        self.block_size = [-1, 4]
        self.nbits = 8

    def test_choose_qparams_codebook(self):
        codebook, wq = choose_qparams_and_quantize_codebook_coreml(
            self.input,
            self.code_dtype,
            self.block_size,
        )
        group_size = self.block_size[-1]
        self.assertEqual(codebook.shape, (256 // group_size, 2**self.nbits, 1))
        self.assertEqual(wq.shape, (100, 256))

        self.assertFalse(torch.isnan(codebook).any())
        self.assertFalse(torch.isnan(wq).any())

    def test_codebook_quantized_tensor_from_float(self):
        cqt = CodebookQuantizedTensor.from_float(
            self.input,
            self.code_dtype,
            self.block_size,
        )

        dequant = cqt.dequantize()
        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)

    def test_codebook_quantized_tensor_from_float2(self):
        block_size = [-1, 16]
        code_dtype = torch.uint4

        cqt = CodebookQuantizedTensor.from_float(
            self.input,
            code_dtype,
            block_size,
        )

        dequant = cqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 18)

    def test_quantize_api(self):
        m = torch.nn.Sequential(torch.nn.Linear(64, 64))
        quantize_(
            m,
            CodebookWeightOnlyConfig(dtype=self.code_dtype, block_size=self.block_size),
        )
        assert type(m[0].weight) == CodebookQuantizedTensor

    @skip_if_no_cuda()
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.")
    def test_export(self):
        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)
        quantize_(m, CodebookWeightOnlyConfig(self.code_dtype, self.block_size))
        example_inputs = (torch.randn(1, 128, dtype=torch.float32),)
        m = torch.export.export(m, example_inputs).module()
        targets = [n.target for n in m.graph.nodes]
        self.assertTrue(torch.ops.quant.dequantize_codebook.default in targets)


if __name__ == "__main__":
    unittest.main()
torchao/prototype/quantization/codebook_coreml/__init__.py (13 additions, 0 deletions)

from .api import CodebookWeightOnlyConfig
from .codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml,
    dequantize_codebook,
)
from .codebook_quantized_tensor import CodebookQuantizedTensor

__all__ = [
    "CodebookQuantizedTensor",
    "CodebookWeightOnlyConfig",
    "choose_qparams_and_quantize_codebook_coreml",
    "dequantize_codebook",
]
torchao/prototype/quantization/codebook_coreml/api.py (54 additions, 0 deletions)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List

import torch

from torchao.core.config import AOBaseConfig
from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import (
    CodebookQuantizedTensor,
)
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import is_package_at_least


@dataclass
class CodebookWeightOnlyConfig(AOBaseConfig):
    dtype: torch.dtype
    block_size: List[int]


@register_quantize_module_handler(CodebookWeightOnlyConfig)
def _codebook_weight_only_transform(
    module: torch.nn.Module,
    config: CodebookWeightOnlyConfig,
):
    """
    Applies codebook weight-only quantization to linear layers.

    Args:
        module: the module whose weight will be quantized.
        config: CodebookWeightOnlyConfig with `dtype` (torch.uint1 through
            torch.uint8) and `block_size` for the quantization.
    Returns:
        The module with its weight replaced by a CodebookQuantizedTensor.
    """
    if not is_package_at_least("coremltools", "8.3.0"):
        raise ImportError("Requires coremltools >= 8.3.0")

    dtype = config.dtype
    block_size = config.block_size
    weight = module.weight

    quantized_weight = CodebookQuantizedTensor.from_float(
        weight,
        dtype,
        block_size,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
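
Because the transform is registered for CodebookWeightOnlyConfig, quantize_ dispatches it to every matching module. A sketch of restricting palettization to selected layers, assuming quantize_'s standard filter_fn argument (a (module, fqn) -> bool callable in torchao):

import torch
from torchao.quantization import quantize_
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 16),  # small head, left in float by the filter below
)

# only palettize Linear layers whose input dimension is at least 256
quantize_(
    model,
    CodebookWeightOnlyConfig(dtype=torch.uint4, block_size=[-1, 16]),
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear)
    and mod.in_features >= 256,
)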
torchao/prototype/quantization/codebook_coreml/codebook_ops.py (176 additions, 0 deletions)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple

import torch

from torchao.quantization.quant_primitives import (
    _DTYPE_TO_BIT_WIDTH,
    _SUB_BYTE_UINT_BOUNDS,
)
from torchao.utils import _register_custom_op

quant_lib = torch.library.Library("quant", "FRAGMENT")
register_custom_op = _register_custom_op(quant_lib)


# wrapper around coreml util: https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/models/neural_network/quantization_utils.py#L363
@torch.no_grad
@register_custom_op
def choose_qparams_and_quantize_codebook_coreml(
    input_tensor: torch.Tensor,
    code_dtype: torch.dtype,
    block_size: List[int],
    force_kmeans1d: bool = False,
    cluster_dim: int = 1,
    vector_axis: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Initialize the codebook using k-means clustering on blocks of the input tensor.

    Args:
        input_tensor (torch.Tensor): The input tensor to be quantized.
        code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
        block_size (List[int]): how many elements of the last dimension of input_tensor
            belong to the same group and should share the same lookup table. Say the original
            shape is (N, K) and block_size is (N, group_size) or (-1, group_size);
            then each slice of (N, group_size) elements uses the same lookup
            table, and there will be (K // group_size) lookup tables.
        force_kmeans1d (bool): Use kmeans1d regardless of the number of weights.
        cluster_dim (int): the size of the vector for vector lookup table quantization.
            E.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
            the tensor in units of 4-element vectors; a vector of the original tensor is mapped to
            a vector in the codebook (lookup table) based on the indices.
        vector_axis (Optional[int]): used in vector quantization, see more docs in https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/optimize/_utils.py#L371

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8)
    """
    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
    assert len(block_size) == input_tensor.ndim
    block_size = block_size.copy()
    for i in range(input_tensor.ndim - 1):
        assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], (
            f"{block_size} not supported"
        )

    group_size = block_size[-1]
    if group_size == -1:
        group_size = input_tensor.shape[-1]

    assert input_tensor.shape[-1] % group_size == 0
    assert input_tensor.ndim == 2
    assert cluster_dim == 1, (
        f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
    )

    # for converting to numpy
    input_tensor = input_tensor.detach()
    # (N, K)
    original_shape = input_tensor.shape
    # (K // group_size)
    num_lut = input_tensor.shape[1] // group_size

    # reshape to (N, K // group_size, group_size)
    input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
    from coremltools.models.neural_network.quantization_utils import (
        _get_kmeans_lookup_table_and_weight,
    )

    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
    if nbits > 8:
        print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size")
        nbits = 8

    res_lut = []
    # each res_w[:, i, :] will use the same lookup table
    # res_w: (N, K // group_size, group_size)
    res_w = torch.zeros_like(input_tensor, dtype=torch.uint8)
    for i in range(num_lut):
        # lut: (2**nbits, 1)
        # w: (N * group_size)
        lut, w = _get_kmeans_lookup_table_and_weight(
            nbits, input_tensor[:, i, :], force_kmeans1d, cluster_dim, vector_axis
        )
        res_lut.append(torch.from_numpy(lut))
        res_w[:, i, :] = torch.from_numpy(w.reshape(input_tensor.shape[0], group_size))

    # directly stack all lookup tables along dim 0
    # res_lut: (K // group_size, 2 ** nbits, 1)
    res_lut = torch.stack(res_lut, dim=0)

    # reshape back to (N, K)
    res_w = res_w.reshape(*original_shape)

    return res_lut, res_w


@register_custom_op
def dequantize_codebook(
    codes: torch.Tensor,
    codebook: torch.Tensor,
    code_dtype: torch.dtype,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Reconstructs the original tensor from codes and the codebook.

    Args:
        codes (torch.Tensor): Indices of codebook entries for each element,
            shape (N, K) for scalar quantization
        codebook (torch.Tensor): Codebook tensor used for quantization,
            shape (K // group_size, 2 ** nbits, 1) where K is the dim 1 shape of the input
        code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8].
            Note that codes is stored in torch.uint8; this is just additional information for the dequantize op
        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table;
            only (-1, ..., group_size) is supported right now (all preceding dimensions have to match the input)
        output_dtype (torch.dtype): dtype for the output tensor.

    Returns:
        dequant (torch.Tensor): Reconstructed tensor, shape (N, K)

    """
    assert output_dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Unsupported output dtype: {output_dtype}"

    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]

    assert len(block_size) == codes.ndim
    block_size = block_size.copy()
    for i in range(codes.ndim - 1):
        assert block_size[i] == -1 or block_size[i] == codes.shape[i], (
            f"{block_size} not supported"
        )

    group_size = block_size[-1]
    if group_size == -1:
        group_size = codes.shape[-1]

    assert codes.shape[-1] % group_size == 0
    K = codes.shape[-1]
    num_lut = K // group_size
    # (N, K)
    original_shape = codes.shape

    # reshape to (N, num_lut, group_size)
    codes = codes.reshape(codes.shape[0], num_lut, group_size)
    dequant = torch.zeros_like(codes, dtype=output_dtype)

    # do lookup for each lookup table
    # dequant shape: (N, num_lut, group_size)
    # codebook shape: (num_lut, 2 ** nbits, 1)
    # codes shape: (N, num_lut, group_size)
    for i in range(num_lut):
        # dequant[:, i, :]: (N, group_size)
        # using squeeze to remove the trailing dim 1s after the lookup
        dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()

    dequant = dequant.reshape(*original_shape)
    return dequant.to(output_dtype)
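
To make the shape conventions concrete, a small round-trip sketch using only the two ops above (sizes mirror the test setup):

import torch
from torchao.prototype.quantization.codebook_coreml import (
    choose_qparams_and_quantize_codebook_coreml,
    dequantize_codebook,
)

w = torch.randn(100, 256, dtype=torch.float32)  # (N, K)
block_size = [-1, 4]  # one lookup table per group of 4 columns

# codebook: (K // 4, 2**8, 1), codes: (100, 256) stored as uint8
codebook, codes = choose_qparams_and_quantize_codebook_coreml(
    w, torch.uint8, block_size
)

# reconstruct (100, 256) float32 weights from the codes
w_hat = dequantize_codebook(codes, codebook, torch.uint8, block_size)
assert w_hat.shape == w.shape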
