Commit ca5f788

Update coreml codebook (#2648)
* Update CoreML codebook APIs
* up
* up
* up
1 parent dc36108 commit ca5f788

File tree

4 files changed: +115 −63 lines changed


test/prototype/test_codebook_coreml.py

Lines changed: 1 addition & 3 deletions
@@ -14,7 +14,6 @@
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
-from torchao.testing.utils import skip_if_no_cuda
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least


@@ -36,7 +35,7 @@ def test_choose_qparams_codebook(self):
             self.block_size,
         )
         group_size = self.block_size[-1]
-        self.assertEqual(codebook.shape, (256 // group_size, 2**self.nbits, 1))
+        self.assertEqual(codebook.shape, (1, 256 // group_size, 2**self.nbits, 1))
         self.assertEqual(wq.shape, (100, 256))

         self.assertFalse(torch.isnan(codebook).any())
@@ -76,7 +75,6 @@ def test_quantize_api(self):
         )
         assert type(m[0].weight) == CodebookQuantizedTensor

-    @skip_if_no_cuda()
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.")
    def test_export(self):
        m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)
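For reference, a minimal sketch of what the updated shape assertion implies, using illustrative values nbits = 4 and block_size = (-1, 64) (the test's real self.nbits and self.block_size live in its setUp and are not shown in this diff). The new leading 1 comes from block_size[0] spanning all 100 rows, so there is a single row group:

# Hypothetical values for illustration only; not taken from the test's setUp.
nbits = 4
block_size = [-1, 64]        # -1 means the whole dimension (all 100 rows)
group_size = block_size[-1]

# New layout: (row groups, column groups, table entries, scalar lookup value)
expected_codebook_shape = (1, 256 // group_size, 2**nbits, 1)   # (1, 4, 16, 1)
expected_codes_shape = (100, 256)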

torchao/prototype/quantization/codebook_coreml/api.py

Lines changed: 1 addition & 2 deletions
@@ -42,13 +42,12 @@ def _codebook_weight_only_transform(
         raise ImportError("Requires coremltools >= 8.3.0")

     dtype = config.dtype
-    block_size = config.block_size
     weight = module.weight

     quantized_weight = CodebookQuantizedTensor.from_float(
         weight,
         dtype,
-        block_size,
+        config.block_size,
     )
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
     return module
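For context, a minimal sketch of how this transform is reached through quantize_. The config class name and import path below are assumptions for illustration; the diff only shows the transform that reads config.dtype and config.block_size:

import torch
from torchao.quantization import quantize_
# Assumed import path and config name, not shown in this diff.
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig

m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)
# dtype and block_size are the two fields _codebook_weight_only_transform reads.
quantize_(m, CodebookWeightOnlyConfig(dtype=torch.uint4, block_size=[-1, 16]))
assert type(m[0].weight).__name__ == "CodebookQuantizedTensor"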

torchao/prototype/quantization/codebook_coreml/codebook_ops.py

Lines changed: 98 additions & 57 deletions
@@ -34,11 +34,12 @@ def choose_qparams_and_quantize_codebook_coreml(
     Args:
         input_tensor (torch.Tensor): The input tensor to be quantized.
         code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
-        block_size (List[int]): the size for how many elements of last dimension of input_tensor
-            belong to the same group and should share the same lookup table. let's say original
-            shape is (N, K), and block_size of (N, group_size) or (-1, group_size),
-            then the slice of (N, group_size) elements should use the same lookup
-            table, and there will be (K // group_size) lookup tables
+        block_size (List[int]): block sizes for how many elements in each dimension share
+            the same lookup table (len(block_size) == input_tensor.dim())
+            Each dimension of input_tensor must be divisible by the corresponding element of block_size
+            Look up tables are indexed by {(di // bi) for i in input_tensor.dim()}
+            For example, if the input tensor has shape (N, K), and block_size is (N, group_size), this means
+            there is a lookup table for group_size columns, i.e., (K // group_size) total look up tables
         force_kmeans1d (bool): Use kmeans1d regardless of number of weights
         cluster_dim (int): this means the size of the vector for vector lookup table quantization
             e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
@@ -48,43 +49,45 @@ def choose_qparams_and_quantize_codebook_coreml(

     Returns:
         Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8)
+        The LUT table has dimension (g0, .., g(N-1), 2**nbits, vec_dim), where:
+        * The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension)
+        * The N + 1 dimension indexes over the nbit indices (2 ** nbits)
+        * The N + 2 dimension indexes over the look up values (shape = 1 for scalar)
     """
     assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
-    assert len(block_size) == input_tensor.ndim
+    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"
+
+    assert len(block_size) == input_tensor.dim()
     block_size = block_size.copy()
-    for i in range(input_tensor.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = input_tensor.shape[i]
+        assert block_size[i] >= 1 and input_tensor.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide input_tensor.shape[i]"
         )

-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = input_tensor.shape[-1]
-
-    assert input_tensor.shape[-1] % group_size == 0
-    assert input_tensor.ndim == 2
+    assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
+    assert block_size[0] == input_tensor.shape[0], (
+        "Currently only support per-grouped channel granularity"
+    )
     assert cluster_dim == 1, (
         f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
     )

+    num_lut = input_tensor.shape[1] // block_size[1]
+    group_size = block_size[1]
+
     # for converting to numpy
     input_tensor = input_tensor.detach()
-    # (N, K)
     original_shape = input_tensor.shape
-    # (K // group_size)
-    num_lut = input_tensor.shape[1] // group_size

     # reshape to (N, K // group_size, group_size)
     input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
     from coremltools.models.neural_network.quantization_utils import (
         _get_kmeans_lookup_table_and_weight,
     )

-    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
-    if nbits > 8:
-        print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size")
-        nbits = 8
-
     res_lut = []
     # each res_w[:, i, :] will use the same lookup table
     # res_w: (N, K // group_size, group_size)
@@ -102,6 +105,13 @@ def choose_qparams_and_quantize_codebook_coreml(
     # res_lut: (K // group_size, 2 ** nbits)
     res_lut = torch.stack(res_lut, dim=0)

+    # The final LUT should have dimension equal to input_tensor.dim() + 2
+    # The first input_tensor.dim() dimensions index over the tables,
+    # input_tensor.dim() + 1 indexes over the nbit indices
+    # input_tensor.dim() + 2 are the look up values (shape = 1 for scalar)
+    # res_lut: (N, K // group_size, 2 ** nbits, group_size)
+    res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)
+
     # reshape back to (N, K)
     res_w = res_w.reshape(*original_shape)

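Taken together, these hunks change the returned lookup table from a rank-3 (K // group_size, 2**nbits, 1) tensor to a rank-4 tensor with one group dimension per input dimension. A minimal sketch of the new contract, assuming coremltools >= 8.3.0 is installed (argument names and the return order follow the docstring above):

import torch
from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml,
)

w = torch.randn(100, 256)
# One table per 64 columns; -1 expands to the full row dimension
# (per-grouped-channel granularity, the only mode supported right now).
codebook, codes = choose_qparams_and_quantize_codebook_coreml(
    w, code_dtype=torch.uint4, block_size=[-1, 64]
)
# New layout: (rows // block_size[0], cols // block_size[1], 2**nbits, vec_dim)
assert codebook.shape == (1, 4, 16, 1)
assert codes.shape == (100, 256)  # codes are stored as torch.uint8 per the docstring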

@@ -112,7 +122,7 @@ def choose_qparams_and_quantize_codebook_coreml(
 def dequantize_codebook(
     codes: torch.Tensor,
     codebook: torch.Tensor,
-    code_dtype: torch.dtype,
+    nbits: int,
     block_size: List[int],
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -121,13 +131,14 @@ def dequantize_codebook(

     Args:
         codes (torch.Tensor): Indices of codebook entries for each element
-            shape (N, K) for scalar quantization
-        codebook (torch.Tensor): Codebook tensor used for quantization,
-            shape (K // group_size, 2 ** nbits) where K is the dim 1 shape of input
-        code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8]
-            Note that codes is stored in torch.uint8, this is just addtional information for dequantize op
-        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table
-            only support (-1, ..., group_size) right now (all preceding dimensions has to match input)
+            General shape: (d0, d1, d2, ..., dN)
+            Simple example shape: (N, K)
+        codebook (torch.Tensor): Codebook tensor used for quantization
+            General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar look up values
+            Simple example shape: (1, group_size, 2 ** nbits, 1) for scalar look up values, with 1 table per group_size columns
+        nbits: int: number of bits for the quantization
+        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table.
+            If block_size[i] == -1, then the entire dimension is used.
         output_dtype (torch.dtype): dtype for the output tensor.

     Returns:
@@ -140,37 +151,67 @@ def dequantize_codebook(
         torch.bfloat16,
     ], f"Unsupported output dtype: {output_dtype}"

-    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"

-    assert len(block_size) == codes.ndim
+    assert len(block_size) == codes.dim()
     block_size = block_size.copy()
-    for i in range(codes.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == codes.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = codes.shape[i]
+        assert block_size[i] >= 1 and codes.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide codes.shape[i]"
         )

-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = codes.shape[-1]
+    assert codebook.dim() == codes.dim() + 2
+    codebook_shape = codebook.shape
+    vec_dim = codebook_shape[-1]
+    quant_levels = 2**nbits

-    assert codes.shape[-1] % group_size == 0
-    K = codes.shape[-1]
-    num_lut = K // group_size
-    # (N, K)
-    original_shape = codes.shape
+    # Check that last two dimensions of codebook are [quant_levels, vec_dim]
+    assert codebook_shape[-2] == quant_levels, "Codebook shape mismatch with nbits"

-    # reshape to (N, num_lut, group_size)
-    codes = codes.reshape(codes.shape[0], num_lut, group_size)
-    dequant = torch.zeros_like(codes, dtype=output_dtype)
+    # Compute shape of lookup group indices from codes shape and block size
+    code_shape = codes.shape
+    ndim = codes.ndim
+    assert len(block_size) == ndim, "block_size must match dimensionality of codes"

-    # do lookup for each lookup table
-    # dequant shape: (N, num_lut, group_size)
-    # codebook shape: (num_lut, 2 ** nbits)
-    # codes shape: (N, num_lut, group_size)
-    for i in range(num_lut):
-        # dequant[:, i, :]: (N, group_size)
-        # using squeeze to remove the training dim 1s after the lookup
-        dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
+    # Compute which codebook slice to use for each element
+    group_indices = []
+    for i in range(ndim):
+        assert block_size[i] >= 1 and code_shape[i] % block_size[i] == 0, (
+            f"dimension {code_shape[i]} not divisible by block size {block_size[i]}"
+        )

-    dequant = dequant.reshape(*original_shape)
-    return dequant.to(output_dtype)
+        # Index of block
+        idx = (
+            torch.arange(code_shape[i], device=codes.device) // block_size[i]
+        )  # shape (di,)
+
+        # Reshape idx to broadcast along all other dims
+        shape = [1] * ndim
+        shape[i] = code_shape[i]
+        idx = idx.view(*shape)  # shape (1, ..., 1, di, 1, ..., 1)
+        idx = idx.expand(code_shape)  # shape (d0, ..., dN)
+        group_indices.append(idx)
+
+    # Stack the broadcasted group indices
+    # group_index_tensor at (i0, i1, ..., iN) is the gives the group indices (g0, ..., gN)
+    # for the element at (i0, i1, ..., iN) in the original code
+    # If code.shape = (d1, d2, d3), then group_index_tensor.shape = (d1, d2, d3, 3)
+    group_index_tensor = torch.stack(
+        group_indices, dim=-1
+    )  # shape (d0, d1, ..., dN, ndim)
+
+    # Flatten everything to index efficiently
+    flat_codes = codes.reshape(-1)  # shape (numel,)
+    flat_groups = group_index_tensor.reshape(-1, ndim)  # (numel, ndim)
+
+    # Compute dequantized values via indexing
+    # index into codebook with (*group_index, code_index, :)
+    gathered = codebook[(*flat_groups.T, flat_codes)]  # shape (numel, vec_dim)
+    dequant = gathered.reshape(*code_shape, vec_dim)
+
+    if vec_dim == 1:
+        dequant = dequant.squeeze(-1)
+
+    return dequant.to(dtype=output_dtype)
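A minimal, self-contained sketch of the new broadcast-index lookup, using plain integer codes for readability (the quantize path stores codes as torch.uint8). Table 0 serves columns 0-3 and table 1 serves columns 4-7:

import torch
from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    dequantize_codebook,
)

# Codebook shape (1, 2, 4, 1): one row group, two column groups,
# 2**2 = 4 entries per table, scalar lookup values (vec_dim = 1).
codebook = torch.tensor(
    [[[[-1.0], [-0.5], [0.5], [1.0]],    # table for columns 0..3
      [[-2.0], [-1.0], [1.0], [2.0]]]]   # table for columns 4..7
)
codes = torch.randint(0, 4, (2, 8))      # 2 x 8 codes, plain int64 for the sketch
deq = dequantize_codebook(codes, codebook, nbits=2, block_size=[-1, 4])

# Element (i, j) is looked up in the table for column group j // 4.
assert torch.equal(deq[:, :4], codebook[0, 0, codes[:, :4], 0])
assert torch.equal(deq[:, 4:], codebook[0, 1, codes[:, 4:], 0])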

torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py

Lines changed: 15 additions & 1 deletion
@@ -12,6 +12,9 @@
     choose_qparams_and_quantize_codebook_coreml,
     dequantize_codebook,
 )
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_BIT_WIDTH,
+)
 from torchao.utils import TorchAOBaseTensor

 aten = torch.ops.aten
@@ -95,7 +98,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         return dequantize_codebook(
             codes,
             self.codebook,
-            self.code_dtype,
+            _DTYPE_TO_BIT_WIDTH[self.code_dtype],
             self.block_size,
             output_dtype=output_dtype,
         )
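For reference, the tensor's logical code dtype is now translated to a bit width before the dequantize op is called; a small illustration of the assumed mapping:

import torch
from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH

# e.g. torch.uint4 maps to 4, the nbits value dequantize_codebook now expects.
assert _DTYPE_TO_BIT_WIDTH[torch.uint4] == 4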
@@ -174,6 +177,17 @@ def _(func, types, args, kwargs):
     return func(input_tensor, weight_tensor, bias)


+@implements([torch.nn.functional.embedding, aten.embedding.default])
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    indices, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    weight_tensor = weight_tensor.dequantize()
+    return func(indices, weight_tensor, **kwargs)
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
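A minimal sketch of what the new embedding override enables, assuming coremltools >= 8.3.0 is installed. The from_float arguments mirror those used in api.py above, and F.embedding is called with exactly the two arguments the override expects:

import torch
import torch.nn.functional as F
from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import (
    CodebookQuantizedTensor,
)

weight = torch.randn(1000, 64)
qweight = CodebookQuantizedTensor.from_float(weight, torch.uint4, [-1, 16])

# F.embedding dispatches to the new override, which dequantizes the codebook
# weight and falls back to the regular lookup.
tokens = torch.randint(0, 1000, (2, 8))
out = F.embedding(tokens, qweight)
assert out.shape == (2, 8, 64)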
