|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from typing import List, Optional, Tuple |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | +from torchao.quantization.quant_primitives import ( |
| 11 | + _DTYPE_TO_BIT_WIDTH, |
| 12 | + _SUB_BYTE_UINT_BOUNDS, |
| 13 | +) |
| 14 | +from torchao.utils import _register_custom_op |
| 15 | + |
| 16 | +quant_lib = torch.library.Library("quant", "FRAGMENT") |
| 17 | +register_custom_op = _register_custom_op(quant_lib) |
| 18 | + |
| 19 | + |
| 20 | +# wrapper around coreml util: https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/models/neural_network/quantization_utils.py#L363 |
| 21 | +@torch.no_grad |
| 22 | +@register_custom_op |
| 23 | +def choose_qparams_and_quantize_codebook_coreml( |
| 24 | + input_tensor: torch.Tensor, |
| 25 | + code_dtype: torch.dtype, |
| 26 | + block_size: List[int], |
| 27 | + force_kmeans1d: bool = False, |
| 28 | + cluster_dim: int = 1, |
| 29 | + vector_axis: Optional[int] = None, |
| 30 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 31 | + """ |
| 32 | + Initialize the codebook using k-means clustering on blocks of the input tensor. |
| 33 | +
|
| 34 | + Args: |
| 35 | + input_tensor (torch.Tensor): The input tensor to be quantized. |
| 36 | + code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8] |
| 37 | + block_size (List[int]): the size for how many elements of last dimension of input_tensor |
| 38 | + belong to the same group and should share the same lookup table. let's say original |
| 39 | + shape is (N, K), and block_size of (N, group_size) or (-1, group_size), |
| 40 | + then the slice of (N, group_size) elements should use the same lookup |
| 41 | + table, and there will be (K // group_size) lookup tables |
| 42 | + force_kmeans1d (bool): Use kmeans1d regardless of number of weights |
| 43 | + cluster_dim (int): this means the size of the vector for vector lookup table quantization |
| 44 | + e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize |
| 45 | + the tensor in a unit of 4 element vectors, a vector of original tensor will be mapped to |
| 46 | + a vector in the codebook (lookup table) based on the indices. |
| 47 | + vector_axis (Optional[int]): used in vector quantization, see more docs in https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/optimize/_utils.py#L371 |
| 48 | +
|
| 49 | + Returns: |
| 50 | + Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8) |
| 51 | + """ |
| 52 | + assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8] |
| 53 | + assert len(block_size) == input_tensor.ndim |
| 54 | + block_size = block_size.copy() |
| 55 | + for i in range(input_tensor.ndim - 1): |
| 56 | + assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], ( |
| 57 | + f"{block_size} not supported" |
| 58 | + ) |
| 59 | + |
| 60 | + group_size = block_size[-1] |
| 61 | + if group_size == -1: |
| 62 | + group_size = input_tensor.shape[-1] |
| 63 | + |
| 64 | + assert input_tensor.shape[-1] % group_size == 0 |
| 65 | + assert input_tensor.ndim == 2 |
| 66 | + assert cluster_dim == 1, ( |
| 67 | + f"only cluster_dim == 1 is supported right now, got {cluster_dim}" |
| 68 | + ) |
| 69 | + |
| 70 | + # for converting to numpy |
| 71 | + input_tensor = input_tensor.detach() |
| 72 | + # (N, K) |
| 73 | + original_shape = input_tensor.shape |
| 74 | + # (K // group_size) |
| 75 | + num_lut = input_tensor.shape[1] // group_size |
| 76 | + |
| 77 | + # reshape to (N, K // group_size, group_size) |
| 78 | + input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size) |
| 79 | + from coremltools.models.neural_network.quantization_utils import ( |
| 80 | + _get_kmeans_lookup_table_and_weight, |
| 81 | + ) |
| 82 | + |
| 83 | + nbits = _DTYPE_TO_BIT_WIDTH[code_dtype] |
| 84 | + if nbits > 8: |
| 85 | + print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size") |
| 86 | + nbits = 8 |
| 87 | + |
| 88 | + res_lut = [] |
| 89 | + # each res_w[:, i, :] will use the same lookup table |
| 90 | + # res_w: (N, K // group_size, group_size) |
| 91 | + res_w = torch.zeros_like(input_tensor, dtype=torch.uint8) |
| 92 | + for i in range(num_lut): |
| 93 | + # lut: (2**nbits, 1) |
| 94 | + # w: (N * group_size) |
| 95 | + lut, w = _get_kmeans_lookup_table_and_weight( |
| 96 | + nbits, input_tensor[:, i, :], force_kmeans1d, cluster_dim, vector_axis |
| 97 | + ) |
| 98 | + res_lut.append(torch.from_numpy(lut)) |
| 99 | + res_w[:, i, :] = torch.from_numpy(w.reshape(input_tensor.shape[0], group_size)) |
| 100 | + |
| 101 | + # directly stack all lookup tables along dim 0 |
| 102 | + # res_lut: (K // group_size, 2 ** nbits) |
| 103 | + res_lut = torch.stack(res_lut, dim=0) |
| 104 | + |
| 105 | + # reshape back to (N, K) |
| 106 | + res_w = res_w.reshape(*original_shape) |
| 107 | + |
| 108 | + return res_lut, res_w |
| 109 | + |
| 110 | + |
| 111 | +@register_custom_op |
| 112 | +def dequantize_codebook( |
| 113 | + codes: torch.Tensor, |
| 114 | + codebook: torch.Tensor, |
| 115 | + code_dtype: torch.dtype, |
| 116 | + block_size: List[int], |
| 117 | + output_dtype: torch.dtype = torch.float32, |
| 118 | +) -> torch.Tensor: |
| 119 | + """ |
| 120 | + Reconstructs the original tensor from codes and the codebook. |
| 121 | +
|
| 122 | + Args: |
| 123 | + codes (torch.Tensor): Indices of codebook entries for each element |
| 124 | + shape (N, K) for scalar quantization |
| 125 | + codebook (torch.Tensor): Codebook tensor used for quantization, |
| 126 | + shape (K // group_size, 2 ** nbits) where K is the dim 1 shape of input |
| 127 | + code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8] |
| 128 | + Note that codes is stored in torch.uint8, this is just addtional information for dequantize op |
| 129 | + block_size (List[int]): a slice of elements with shape block_size will share the same lookup table |
| 130 | + only support (-1, ..., group_size) right now (all preceding dimensions has to match input) |
| 131 | + output_dtype (torch.dtype): dtype for the output tensor. |
| 132 | +
|
| 133 | + Returns: |
| 134 | + dequant (torch.Tensor): Reconstructed tensor, shape (N, K) |
| 135 | +
|
| 136 | + """ |
| 137 | + assert output_dtype in [ |
| 138 | + torch.float32, |
| 139 | + torch.float16, |
| 140 | + torch.bfloat16, |
| 141 | + ], f"Unsupported output dtype: {output_dtype}" |
| 142 | + |
| 143 | + assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8] |
| 144 | + |
| 145 | + assert len(block_size) == codes.ndim |
| 146 | + block_size = block_size.copy() |
| 147 | + for i in range(codes.ndim - 1): |
| 148 | + assert block_size[i] == -1 or block_size[i] == codes.shape[i], ( |
| 149 | + f"{block_size} not supported" |
| 150 | + ) |
| 151 | + |
| 152 | + group_size = block_size[-1] |
| 153 | + if group_size == -1: |
| 154 | + group_size = codes.shape[-1] |
| 155 | + |
| 156 | + assert codes.shape[-1] % group_size == 0 |
| 157 | + K = codes.shape[-1] |
| 158 | + num_lut = K // group_size |
| 159 | + # (N, K) |
| 160 | + original_shape = codes.shape |
| 161 | + |
| 162 | + # reshape to (N, num_lut, group_size) |
| 163 | + codes = codes.reshape(codes.shape[0], num_lut, group_size) |
| 164 | + dequant = torch.zeros_like(codes, dtype=output_dtype) |
| 165 | + |
| 166 | + # do lookup for each lookup table |
| 167 | + # dequant shape: (N, num_lut, group_size) |
| 168 | + # codebook shape: (num_lut, 2 ** nbits) |
| 169 | + # codes shape: (N, num_lut, group_size) |
| 170 | + for i in range(num_lut): |
| 171 | + # dequant[:, i, :]: (N, group_size) |
| 172 | + # using squeeze to remove the training dim 1s after the lookup |
| 173 | + dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze() |
| 174 | + |
| 175 | + dequant = dequant.reshape(*original_shape) |
| 176 | + return dequant.to(output_dtype) |
0 commit comments