@@ -34,11 +34,12 @@ def choose_qparams_and_quantize_codebook_coreml(
    Args:
        input_tensor (torch.Tensor): The input tensor to be quantized.
        code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
-        block_size (List[int]): the size for how many elements of last dimension of input_tensor
-            belong to the same group and should share the same lookup table. let's say original
-            shape is (N, K), and block_size of (N, group_size) or (-1, group_size),
-            then the slice of (N, group_size) elements should use the same lookup
-            table, and there will be (K // group_size) lookup tables
+        block_size (List[int]): block sizes for how many elements in each dimension share
+            the same lookup table (len(block_size) == input_tensor.dim()).
+            Each dimension of input_tensor must be divisible by the corresponding element of block_size.
+            Lookup tables are indexed by (d_i // block_size[i]) along each dimension i.
+            For example, if the input tensor has shape (N, K) and block_size is (N, group_size),
+            there is one lookup table per group of group_size columns, i.e., (K // group_size) tables in total.
        force_kmeans1d (bool): Use kmeans1d regardless of number of weights
        cluster_dim (int): this means the size of the vector for vector lookup table quantization
            e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
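To make the new block_size semantics concrete, here is a small sketch with toy sizes (illustrative only, not part of the diff):

```python
import torch

# Toy illustration of the documented block_size semantics: an (8, 64)
# weight with block_size (8, 16) shares one lookup table per (8, 16)
# block, i.e. one table per 16 columns.
input_tensor = torch.randn(8, 64)
block_size = [8, 16]  # block_size[i] == -1 would expand to input_tensor.shape[i]

# number of tables along each dimension: d_i // block_size[i]
tables_per_dim = [d // b for d, b in zip(input_tensor.shape, block_size)]
assert tables_per_dim == [1, 4]  # (K // group_size) == 64 // 16 tables in total
```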
@@ -48,43 +49,45 @@ def choose_qparams_and_quantize_codebook_coreml(

    Returns:
        Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8)
+        The LUT has dimension (g0, ..., g(N-1), 2**nbits, vec_dim), where:
+        * The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension)
+        * Dimension N + 1 indexes over the nbit indices (2 ** nbits)
+        * Dimension N + 2 indexes over the lookup values (shape = 1 for scalar)
    """
    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
-    assert len(block_size) == input_tensor.ndim
+    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"
+
+    assert len(block_size) == input_tensor.dim()
    block_size = block_size.copy()
-    for i in range(input_tensor.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = input_tensor.shape[i]
+        assert block_size[i] >= 1 and input_tensor.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide input_tensor.shape[i]"
        )

-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = input_tensor.shape[-1]
-
-    assert input_tensor.shape[-1] % group_size == 0
-    assert input_tensor.ndim == 2
+    assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
+    assert block_size[0] == input_tensor.shape[0], (
+        "Currently only per-grouped channel granularity is supported"
+    )

    assert cluster_dim == 1, (
        f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
    )

+    num_lut = input_tensor.shape[1] // block_size[1]
+    group_size = block_size[1]
+
    # for converting to numpy
    input_tensor = input_tensor.detach()
-    # (N, K)
    original_shape = input_tensor.shape
-    # (K // group_size)
-    num_lut = input_tensor.shape[1] // group_size

    # reshape to (N, K // group_size, group_size)
    input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
    from coremltools.models.neural_network.quantization_utils import (
        _get_kmeans_lookup_table_and_weight,
    )

-    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
-    if nbits > 8:
-        print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size")
-        nbits = 8
-
    res_lut = []
    # each res_w[:, i, :] will use the same lookup table
    # res_w: (N, K // group_size, group_size)
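The new validation replaces the old "all leading dims must match" rule with a general divisibility check plus -1 expansion. A standalone sketch of the equivalent normalization (toy shapes, for illustration):

```python
import torch

# Mirrors the diff's loop: -1 expands to the full extent of that
# dimension, then every dimension must be divisible by its block size.
t = torch.randn(4, 32)
block_size = [-1, 8]
block_size = [t.shape[i] if b == -1 else b for i, b in enumerate(block_size)]
assert block_size == [4, 8]
assert all(t.shape[i] % b == 0 for i, b in enumerate(block_size))
```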
@@ -102,6 +105,13 @@ def choose_qparams_and_quantize_codebook_coreml(
    # res_lut: (K // group_size, 2 ** nbits)
    res_lut = torch.stack(res_lut, dim=0)

+    # The final LUT should have dimension equal to input_tensor.dim() + 2:
+    # the first input_tensor.dim() dimensions index over the tables,
+    # dimension input_tensor.dim() + 1 indexes over the nbit indices,
+    # and dimension input_tensor.dim() + 2 holds the lookup values (shape = 1 for scalar)
+    # res_lut: (1, K // group_size, 2 ** nbits, 1)
+    res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)
+
    # reshape back to (N, K)
    res_w = res_w.reshape(*original_shape)
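Shape walk-through of the new reshape, with assumed toy sizes (N = 4, K = 32, group_size = 8, nbits = 4):

```python
import torch

# After stacking, res_lut is (K // group_size, 2**nbits); the reshape
# adds the leading per-dimension table indices and a trailing vec_dim.
num_lut, nbits = 32 // 8, 4
res_lut = torch.randn(num_lut, 2**nbits)            # stacked per-group tables
res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)  # (g0, g1, 2**nbits, vec_dim)
assert res_lut.shape == (1, 4, 16, 1)
```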
107
117
@@ -112,7 +122,7 @@ def choose_qparams_and_quantize_codebook_coreml(
def dequantize_codebook(
    codes: torch.Tensor,
    codebook: torch.Tensor,
-    code_dtype: torch.dtype,
+    nbits: int,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
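Callers that previously passed a code dtype now pass the bit width directly. A hypothetical migration sketch (the mapping below is a stand-in for the file's _DTYPE_TO_BIT_WIDTH; sub-byte uint dtypes require a recent PyTorch):

```python
import torch

# Illustrative stand-in for _DTYPE_TO_BIT_WIDTH, not the real mapping.
example_dtype_to_bits = {torch.uint1: 1, torch.uint2: 2, torch.uint4: 4, torch.uint8: 8}

# Before: dequantize_codebook(codes, codebook, torch.uint4, block_size)
# After:  dequantize_codebook(codes, codebook, nbits, block_size)
nbits = example_dtype_to_bits[torch.uint4]
assert nbits == 4
```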
@@ -121,13 +131,14 @@ def dequantize_codebook(

    Args:
        codes (torch.Tensor): Indices of codebook entries for each element
-            shape (N, K) for scalar quantization
-        codebook (torch.Tensor): Codebook tensor used for quantization,
-            shape (K // group_size, 2 ** nbits) where K is the dim 1 shape of input
-        code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8]
-            Note that codes is stored in torch.uint8, this is just addtional information for dequantize op
-        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table
-            only support (-1, ..., group_size) right now (all preceding dimensions has to match input)
+            General shape: (d0, d1, d2, ..., dN)
+            Simple example shape: (N, K)
+        codebook (torch.Tensor): Codebook tensor used for quantization
+            General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar lookup values
+            Simple example shape: (1, K // group_size, 2 ** nbits, 1) for scalar lookup values, with one table per group of group_size columns
+        nbits (int): number of bits for the quantization
+        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table.
+            If block_size[i] == -1, the full extent of dimension i is used.
        output_dtype (torch.dtype): dtype for the output tensor.

    Returns:
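A runnable sketch of the documented shape contract (assumed toy sizes), doing the lookup by hand to show how codes and codebook line up:

```python
import torch

# Toy sizes: N=2 rows, K=8 columns, one table per 4 columns, 2-bit codes,
# scalar lookup values (vec_dim = 1).
N, K, group_size, nbits = 2, 8, 4, 2
codes = torch.randint(0, 2**nbits, (N, K))
codebook = torch.randn(1, K // group_size, 2**nbits, 1)

# Element (n, k) dequantizes to codebook[n // N, k // group_size, codes[n, k], 0]
expected = torch.empty(N, K)
for n in range(N):
    for k in range(K):
        expected[n, k] = codebook[0, k // group_size, int(codes[n, k]), 0]
```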
@@ -140,37 +151,67 @@ def dequantize_codebook(
        torch.bfloat16,
    ], f"Unsupported output dtype: {output_dtype}"

-    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"

-    assert len(block_size) == codes.ndim
+    assert len(block_size) == codes.dim()
    block_size = block_size.copy()
-    for i in range(codes.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == codes.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = codes.shape[i]
+        assert block_size[i] >= 1 and codes.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide codes.shape[i]"
        )

-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = codes.shape[-1]
+    assert codebook.dim() == codes.dim() + 2
+    codebook_shape = codebook.shape
+    vec_dim = codebook_shape[-1]
+    quant_levels = 2**nbits

-    assert codes.shape[-1] % group_size == 0
-    K = codes.shape[-1]
-    num_lut = K // group_size
-    # (N, K)
-    original_shape = codes.shape
+    # Check that the last two dimensions of codebook are [quant_levels, vec_dim]
+    assert codebook_shape[-2] == quant_levels, "Codebook shape mismatch with nbits"

-    # reshape to (N, num_lut, group_size)
-    codes = codes.reshape(codes.shape[0], num_lut, group_size)
-    dequant = torch.zeros_like(codes, dtype=output_dtype)
+    # Compute shape of lookup group indices from codes shape and block size
+    code_shape = codes.shape
+    ndim = codes.ndim
+    assert len(block_size) == ndim, "block_size must match dimensionality of codes"

-    # do lookup for each lookup table
-    # dequant shape: (N, num_lut, group_size)
-    # codebook shape: (num_lut, 2 ** nbits)
-    # codes shape: (N, num_lut, group_size)
-    for i in range(num_lut):
-        # dequant[:, i, :]: (N, group_size)
-        # using squeeze to remove the training dim 1s after the lookup
-        dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
+    # Compute which codebook slice to use for each element
+    group_indices = []
+    for i in range(ndim):
+        assert block_size[i] >= 1 and code_shape[i] % block_size[i] == 0, (
+            f"dimension {code_shape[i]} not divisible by block size {block_size[i]}"
+        )

-    dequant = dequant.reshape(*original_shape)
-    return dequant.to(output_dtype)
+        # Index of the block along dimension i
+        idx = (
+            torch.arange(code_shape[i], device=codes.device) // block_size[i]
+        )  # shape (di,)
+
+        # Reshape idx to broadcast along all other dims
+        shape = [1] * ndim
+        shape[i] = code_shape[i]
+        idx = idx.view(*shape)  # shape (1, ..., 1, di, 1, ..., 1)
+        idx = idx.expand(code_shape)  # shape (d0, ..., dN)
+        group_indices.append(idx)
+
+    # Stack the broadcasted group indices
+    # group_index_tensor at (i0, i1, ..., iN) gives the group indices (g0, ..., gN)
+    # for the element at (i0, i1, ..., iN) in the original codes
+    # If codes.shape = (d1, d2, d3), then group_index_tensor.shape = (d1, d2, d3, 3)
+    group_index_tensor = torch.stack(
+        group_indices, dim=-1
+    )  # shape (d0, d1, ..., dN, ndim)
+
+    # Flatten everything to index efficiently
+    flat_codes = codes.reshape(-1)  # shape (numel,)
+    flat_groups = group_index_tensor.reshape(-1, ndim)  # (numel, ndim)
+
+    # Compute dequantized values via advanced indexing:
+    # index into codebook with (*group_index, code_index, :)
+    gathered = codebook[(*flat_groups.T, flat_codes)]  # shape (numel, vec_dim)
+    dequant = gathered.reshape(*code_shape, vec_dim)
+
+    if vec_dim == 1:
+        dequant = dequant.squeeze(-1)
+
+    return dequant.to(dtype=output_dtype)
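As a sanity check on the broadcast-and-gather approach, a self-contained sketch (assumed toy sizes) comparing the vectorized indexing against a per-element loop:

```python
import torch

# Rank-2 toy case: 4x6 codes, 2x3 blocks, 2-bit codes, scalar values.
d0, d1, b0, b1, nbits = 4, 6, 2, 3, 2
codes = torch.randint(0, 2**nbits, (d0, d1))
codebook = torch.randn(d0 // b0, d1 // b1, 2**nbits, 1)

# Broadcast block indices for each dimension, as in the diff.
g0 = (torch.arange(d0) // b0).view(d0, 1).expand(d0, d1)
g1 = (torch.arange(d1) // b1).view(1, d1).expand(d0, d1)
gathered = codebook[g0.reshape(-1), g1.reshape(-1), codes.reshape(-1)]
vectorized = gathered.reshape(d0, d1, 1).squeeze(-1)

# Reference: per-element loop over the same indexing rule.
looped = torch.empty(d0, d1)
for i in range(d0):
    for j in range(d1):
        looped[i, j] = codebook[i // b0, j // b1, int(codes[i, j]), 0]

assert torch.equal(vectorized, looped)
```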