|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from dataclasses import dataclass |
| 16 | +from typing import Dict, List, Tuple, Union |
| 17 | + |
| 18 | +import torch |
| 19 | +from compressed_tensors.compressors.base import BaseCompressor |
| 20 | +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor |
| 21 | +from compressed_tensors.config import CompressionFormat, SparsityStructure |
| 22 | +from compressed_tensors.quantization import FP8_DTYPE |
| 23 | +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks |
| 24 | +from torch import Tensor |
| 25 | + |
| 26 | + |
| 27 | +__all__ = [ |
| 28 | + "Sparse24BitMaskCompressor", |
| 29 | + "Sparse24BitMaskTensor", |
| 30 | + "sparse24_bitmask_compress", |
| 31 | + "sparse24_bitmask_decompress", |
| 32 | + "get_24_bytemasks", |
| 33 | +] |
| 34 | + |
| 35 | + |
| 36 | +@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value) |
| 37 | +class Sparse24BitMaskCompressor(BaseSparseCompressor): |
| 38 | + """ |
| 39 | + Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d |
| 40 | + values tensor, with their locations stored in a 2d bitmask |
| 41 | + """ |
| 42 | + |
| 43 | + COMPRESSION_PARAM_NAMES = [ |
| 44 | + "shape", |
| 45 | + "compressed", |
| 46 | + "bitmask", |
| 47 | + ] |
| 48 | + |
| 49 | + def compress_weight(self, name, value): |
| 50 | + bitmask_tensor = Sparse24BitMaskTensor.from_dense( |
| 51 | + value, self.config.sparsity_structure |
| 52 | + ) |
| 53 | + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") |
| 54 | + return bitmask_dict |
| 55 | + |
| 56 | + def decompress_weight(self, weight_data): |
| 57 | + data = Sparse24BitMaskTensor.from_compressed_data(**weight_data) |
| 58 | + decompressed = data.decompress() |
| 59 | + return decompressed |
| 60 | + |
| 61 | + |
| 62 | +@dataclass |
| 63 | +class Sparse24BitMaskTensor: |
| 64 | + """ |
| 65 | + Owns compressions and decompression for a single 2:4 sparse |
| 66 | + bitmask compressed tensor. |
| 67 | +
|
| 68 | + :param shape: shape of dense tensor |
| 69 | + :param compressed: 2d tensor of non-zero values |
| 70 | + :param bitmask: 2d bitmask of non-zero values |
| 71 | + """ |
| 72 | + |
| 73 | + shape: List[int] |
| 74 | + compressed: Tensor |
| 75 | + bitmask: Tensor |
| 76 | + |
| 77 | + @staticmethod |
| 78 | + def from_dense( |
| 79 | + tensor: Tensor, |
| 80 | + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, |
| 81 | + ) -> "Sparse24BitMaskTensor": |
| 82 | + """ |
| 83 | + :param tensor: dense tensor to compress |
| 84 | + :return: instantiated compressed tensor |
| 85 | + """ |
| 86 | + shape = list(tensor.shape) |
| 87 | + compressed, bitmask = sparse24_bitmask_compress( |
| 88 | + tensor.cpu(), sparsity_structure=sparsity_structure |
| 89 | + ) |
| 90 | + return Sparse24BitMaskTensor( |
| 91 | + shape=shape, |
| 92 | + compressed=compressed, |
| 93 | + bitmask=bitmask, |
| 94 | + ) |
| 95 | + |
| 96 | + @staticmethod |
| 97 | + def from_compressed_data( |
| 98 | + shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor |
| 99 | + ) -> "Sparse24BitMaskTensor": |
| 100 | + """ |
| 101 | + :param shape: shape of the dense tensor (can be a list or a tensor) |
| 102 | + :param compressed: 2d tensor of non-zero values |
| 103 | + :param bitmask: 2d bitmask of non-zero values |
| 104 | + :return: instantiated Sparse24BitMaskTensor |
| 105 | + """ |
| 106 | + if isinstance(shape, Tensor): |
| 107 | + shape = shape.tolist() |
| 108 | + return Sparse24BitMaskTensor( |
| 109 | + shape=shape, compressed=compressed, bitmask=bitmask |
| 110 | + ) |
| 111 | + |
| 112 | + def decompress(self) -> Tensor: |
| 113 | + """ |
| 114 | + :return: reconstructed dense tensor |
| 115 | + """ |
| 116 | + return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape) |
| 117 | + |
| 118 | + def curr_memory_size_bytes(self) -> int: |
| 119 | + """ |
| 120 | + :return: size in bytes required to store compressed tensor on disk |
| 121 | + """ |
| 122 | + |
| 123 | + def sizeof_tensor(a: Tensor) -> int: |
| 124 | + return a.element_size() * a.nelement() |
| 125 | + |
| 126 | + return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask) |
| 127 | + |
| 128 | + def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: |
| 129 | + """ |
| 130 | + :param name_prefix: name of original tensor to store compressed weight as |
| 131 | + :return: dict of compressed data for the stored weight |
| 132 | + """ |
| 133 | + if name_prefix.endswith(".weight"): |
| 134 | + name_prefix = name_prefix[: -len(".weight")] |
| 135 | + return { |
| 136 | + merge_names(name_prefix, "shape"): torch.tensor( |
| 137 | + self.shape, device=device |
| 138 | + ).reshape(-1, 1), |
| 139 | + merge_names(name_prefix, "compressed"): self.compressed.to(device), |
| 140 | + merge_names(name_prefix, "bitmask"): self.bitmask.to(device), |
| 141 | + } |
| 142 | + |
| 143 | + def __repr__(self) -> str: |
| 144 | + return f"BitMaskTensor(shape={self.shape}, compressed=True)" |
| 145 | + |
| 146 | + |
| 147 | +def sparse24_bitmask_compress( |
| 148 | + tensor: Tensor, |
| 149 | + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, |
| 150 | +) -> Tuple[Tensor, Tensor, Tensor]: |
| 151 | + """ |
| 152 | + Compresses a dense tensor using bitmask compression |
| 153 | +
|
| 154 | + :param tensor: dense 2D tensor to compress |
| 155 | + :param sparsity_structure: structure of sparsity in the tensor, defaults |
| 156 | + to unstructured, can also be set to `2:4` |
| 157 | + :return: tuple of compressed data representing tensor |
| 158 | + """ |
| 159 | + assert len(tensor.shape) == 2, "Only 2D tensors are supported" |
| 160 | + assert ( |
| 161 | + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR |
| 162 | + ), "Only 2:4 sparsity is supported" |
| 163 | + |
| 164 | + bytemasks = get_24_bytemasks(tensor=tensor) |
| 165 | + |
| 166 | + if tensor.dtype == FP8_DTYPE: |
| 167 | + # acces raw bytes of the tensor |
| 168 | + tensor_view = tensor.view(torch.int8) |
| 169 | + values = tensor_view[bytemasks] |
| 170 | + values = values.view(FP8_DTYPE) |
| 171 | + else: |
| 172 | + values = tensor[bytemasks] |
| 173 | + |
| 174 | + num_rows, num_cols = tensor.shape |
| 175 | + compressed_values = values.reshape(num_rows, num_cols // 2) |
| 176 | + bitmasks_packed = pack_bitmasks(bytemasks) |
| 177 | + return compressed_values, bitmasks_packed |
| 178 | + |
| 179 | + |
| 180 | +def sparse24_bitmask_decompress( |
| 181 | + values: Tensor, bitmasks: Tensor, original_shape: torch.Size |
| 182 | +) -> Tensor: |
| 183 | + """ |
| 184 | + Reconstructs a dense tensor from a compressed one |
| 185 | +
|
| 186 | + :param values: 1d tensor of non-zero values |
| 187 | + :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the |
| 188 | + tensors original shape |
| 189 | + :param original_shape: shape of the dense tensor |
| 190 | + :return: decompressed dense tensor |
| 191 | + """ |
| 192 | + bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape) |
| 193 | + |
| 194 | + decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) |
| 195 | + decompressed_tensor = decompressed_tensor.to(values.device) |
| 196 | + values = values.flatten() |
| 197 | + if decompressed_tensor.dtype == FP8_DTYPE: |
| 198 | + decompressed_tensor[bytemasks_unpacked] = values |
| 199 | + decompressed_tensor = decompressed_tensor.cuda() |
| 200 | + else: |
| 201 | + decompressed_tensor[bytemasks_unpacked] = values |
| 202 | + return decompressed_tensor |
| 203 | + |
| 204 | + |
| 205 | +def get_24_bytemasks(tensor): |
| 206 | + """ |
| 207 | + Generate a 2:4 sparsity mask for the given tensor. |
| 208 | +
|
| 209 | + This function creates a mask where exactly 2 out of every 4 elements are |
| 210 | + preserved based on their magnitudes. The preserved elements are the ones |
| 211 | + with the highest absolute values in each group of 4 elements. |
| 212 | +
|
| 213 | + :param tensor: The input tensor for which the 2:4 sparsity mask is to be created. |
| 214 | + The tensor can be of any shape but its total number of elements |
| 215 | + must be a multiple of 4. |
| 216 | + :return: A boolean tensor of the same shape as the input tensor, where `True` |
| 217 | + indicates the preserved elements and `False` indicates the pruned elements. |
| 218 | + :raises ValueError: If the total number of elements in the tensor is not a |
| 219 | + multiple of 4. |
| 220 | + """ |
| 221 | + original_dtype = tensor.dtype |
| 222 | + if tensor.dtype == FP8_DTYPE: |
| 223 | + tensor = tensor.view(torch.int8) |
| 224 | + original_shape = tensor.shape |
| 225 | + num_elements = tensor.numel() |
| 226 | + |
| 227 | + if num_elements % 4 != 0: |
| 228 | + raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity") |
| 229 | + |
| 230 | + reshaped_tensor = tensor.view(-1, 4) |
| 231 | + abs_tensor = reshaped_tensor.abs() |
| 232 | + topk_indices = abs_tensor.topk(2, dim=1).indices |
| 233 | + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) |
| 234 | + mask.scatter_(1, topk_indices, True) |
| 235 | + mask = mask.view(original_shape) |
| 236 | + tensor = tensor.view(original_dtype) |
| 237 | + |
| 238 | + return mask |
0 commit comments