|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +from typing import List |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from torchao.quantization.quant_primitives import ( |
| 13 | + MappingType, |
| 14 | + choose_qparams_affine, |
| 15 | + quantize_affine, |
| 16 | +) |
| 17 | +from torchao.utils import TorchAOBaseTensor |
| 18 | + |
| 19 | +__all__ = [ |
| 20 | + "Int4MarlinSparseTensor", |
| 21 | +] |
| 22 | + |
| 23 | +aten = torch.ops.aten |
| 24 | + |
| 25 | + |
class Int4MarlinSparseTensor(TorchAOBaseTensor):
    """int4 weight tensor packed for the Marlin 2:4 sparse GEMM kernel.

    Tensor data:
        qdata: int4 weights compressed to the marlin 2:4 format by
            ``pack_to_marlin_24``
        scale: quantization scales, packed by ``pack_to_marlin_24``
        zero_point: zero points from the affine quantization step (stored on
            the tensor but not packed into the marlin format)
        meta: 2:4 sparsity metadata returned by ``pack_to_marlin_24``

    Tensor attributes:
        block_size: quantization group size along the packed dimension
        num_bits: quantization bit width (currently always 4)
        shape: logical shape of the (transposed) quantized weight
    """

    tensor_data_names = ["qdata", "scale", "zero_point", "meta"]
    tensor_attribute_names = ["block_size", "num_bits", "shape"]

    def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        kwargs = {}
        kwargs["device"] = qdata.device
        # The wrapper's dtype follows the high-precision scale, not the packed
        # int data, so the tensor presents itself as fp to the rest of the model.
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.meta = meta
        self.block_size = block_size
        self.num_bits = num_bits

    def _quantization_type(self):
        # Short human-readable summary used in reprs/debugging.
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        """Quantize a high-precision weight to int4 and pack it for the Marlin
        2:4 sparse kernel.

        Preprocessing of the input tensor:
        - 1º: the input tensor is transposed since the linear layer keeps the
          weights in a transposed format
        - 2º: tensor is injected with 2:4 sparsity
        - 3º: transposes it again because the quantization process will compute
          the scales for dim=-1

        Args:
            w: high-precision weight, (out_features, in_features)
            block_size: quantization granularity per dimension; the last entry
                must be 128 or equal to the row size (per-channel)

        Raises:
            ValueError: if the device capability, dtype, shape, bit width, or
                group size is not supported by the marlin 2:4 kernel.
        """
        from torchao.sparsity.marlin import (
            const,
            inject_24,  # avoid circular import
            pack_to_marlin_24,
        )

        w_t = w.t()
        w_24, _ = inject_24(w_t, *w_t.shape)
        preprocessed_w = w_24.t()

        assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], (
            f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}"
        )

        quant_min = 0
        quant_max = 15
        target_dtype = torch.int32

        scale, zero_point = choose_qparams_affine(
            input=preprocessed_w,
            mapping_type=MappingType.SYMMETRIC,
            block_size=block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=1e-6,
        )

        wq = quantize_affine(
            input=preprocessed_w,
            block_size=block_size,
            scale=scale,
            zero_point=zero_point,
            output_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
        )

        scale = scale.to(w.dtype)
        zero_point = zero_point.to(w.dtype)

        # Linear layers are (in_features, out_features) but the qdata that is reaching this point
        # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
        q_w_24 = wq.t()
        # addressing the case when scale has dimension 1, happens when
        # weight_shape[-1] == group_size == 128
        if scale.ndim == 1:
            scale = scale.reshape(scale.shape[0], -1)

        scale_t = scale.t()

        if not torch.cuda.get_device_capability()[0] >= 8:
            raise ValueError(
                f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel."
            )

        if q_w_24.dtype != torch.int32:
            raise ValueError("Only `torch.int32` weights are supported.")

        in_features, out_features = q_w_24.shape
        # FIX: the original check was `out_features != 256 == 0`, a chained
        # comparison equivalent to `out_features != 256 and 256 == 0`, which is
        # always False — the out_features constraint was never enforced. It is
        # meant to be a divisibility test, matching the in_features check.
        if in_features % 128 != 0 or out_features % 256 != 0:
            raise ValueError(
                "`in_features` must be divisible by 128 and `out_features` by 256."
            )

        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
        # will require a bit more work to get our current quantization flow to work with it.
        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
        num_bits = 4 if torch.max(q_w_24) < 16 else -1
        if num_bits not in [4]:
            raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.")

        group_size = in_features // scale_t.shape[0]
        if group_size == 0:
            group_size = in_features
        assert group_size <= in_features, (
            "Group size must be less than or equal to in_features."
        )

        if group_size not in const.SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
            )

        # Compress quantized weight to marlin 2:4 format
        marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(
            q_w_24, scale_t, num_bits, group_size
        )

        return cls(
            qdata=marlin_24_q_w_comp,
            scale=marlin_24_s,
            zero_point=zero_point,
            meta=meta,
            block_size=group_size,
            shape=q_w_24.shape,
            num_bits=num_bits,
        )
| 158 | + |
| 159 | + |
implements = Int4MarlinSparseTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    """Route F.linear / aten.linear onto the fused marlin 2:4 sparse GEMM."""
    from torchao.ops import marlin_24_gemm
    from torchao.sparsity.marlin import marlin_24_workspace

    input_tensor = args[0]
    weight_tensor = args[1]
    bias = args[2] if len(args) > 2 else None

    assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
    assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
    assert weight_tensor.zero_point.is_contiguous(), (
        "Expected zero_point to be contiguous"
    )

    packed_w = weight_tensor.qdata
    packed_scale = weight_tensor.scale
    sparsity_meta = weight_tensor.meta
    logical_shape = weight_tensor.shape
    bit_width = weight_tensor.num_bits

    # Collapse any leading batch dimensions so the kernel sees a 2D problem.
    x_2d = input_tensor.view(-1, input_tensor.shape[-1])

    size_m = x_2d.shape[0]
    size_n = packed_scale.shape[1]
    size_k = x_2d.shape[1]
    workspace_24 = marlin_24_workspace(logical_shape[1])

    out = marlin_24_gemm(
        x_2d,
        packed_w,
        sparsity_meta,
        packed_scale,
        workspace_24,
        bit_width,
        size_m,
        size_n,
        size_k,
    )

    # Restore the original leading batch dimensions.
    out = out.reshape(input_tensor.shape[:-1] + (packed_scale.shape[1],))

    if bias is not None:
        out += bias.to(out.dtype)
    return out
| 211 | + |
| 212 | + |
# Present the class under the public torchao.quantization namespace (this sets
# the module path used in reprs and serialization).
Int4MarlinSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4MarlinSparseTensor])