Skip to content

Commit f0c369a

Browse files
authored
Add 24 sparse bitmask (#235)
* Add: Support for targets and ignore in SparseCompressors Enable: Operations on state_dict to allow composability Add: Composability for compress/decompress pathways Update: Typing for a few methods Add: Composability Test Add: Some testing utils * Review Comments! * More review comments from @dsikka * Fix failing tests * Adds: Fully shardable Sparse24BitMaskCompressor Adds: Sharding test * review comments from @dsikka * Convert SparseBitMaskTensor to a dataclass * Add tests fro int8 Add a requires gpu decorator in testing_utils Enable fp8 tests if gpu available * typo's * Revert function name change
1 parent 7801f00 commit f0c369a

File tree

9 files changed

+612
-39
lines changed

9 files changed

+612
-39
lines changed

src/compressed_tensors/compressors/sparse_compressors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515

1616
from .base import *
1717
from .dense import *
18+
from .sparse_24_bitmask import *
1819
from .sparse_bitmask import *
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import Dict, List, Tuple, Union
17+
18+
import torch
19+
from compressed_tensors.compressors.base import BaseCompressor
20+
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
21+
from compressed_tensors.config import CompressionFormat, SparsityStructure
22+
from compressed_tensors.quantization import FP8_DTYPE
23+
from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
24+
from torch import Tensor
25+
26+
27+
__all__ = [
28+
"Sparse24BitMaskCompressor",
29+
"Sparse24BitMaskTensor",
30+
"sparse24_bitmask_compress",
31+
"sparse24_bitmask_decompress",
32+
"get_24_bytemasks",
33+
]
34+
35+
36+
@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
37+
class Sparse24BitMaskCompressor(BaseSparseCompressor):
38+
"""
39+
Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
40+
values tensor, with their locations stored in a 2d bitmask
41+
"""
42+
43+
COMPRESSION_PARAM_NAMES = [
44+
"shape",
45+
"compressed",
46+
"bitmask",
47+
]
48+
49+
def compress_weight(self, name, value):
50+
bitmask_tensor = Sparse24BitMaskTensor.from_dense(
51+
value, self.config.sparsity_structure
52+
)
53+
bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
54+
return bitmask_dict
55+
56+
def decompress_weight(self, weight_data):
57+
data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
58+
decompressed = data.decompress()
59+
return decompressed
60+
61+
62+
@dataclass
63+
class Sparse24BitMaskTensor:
64+
"""
65+
Owns compressions and decompression for a single 2:4 sparse
66+
bitmask compressed tensor.
67+
68+
:param shape: shape of dense tensor
69+
:param compressed: 2d tensor of non-zero values
70+
:param bitmask: 2d bitmask of non-zero values
71+
"""
72+
73+
shape: List[int]
74+
compressed: Tensor
75+
bitmask: Tensor
76+
77+
@staticmethod
78+
def from_dense(
79+
tensor: Tensor,
80+
sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
81+
) -> "Sparse24BitMaskTensor":
82+
"""
83+
:param tensor: dense tensor to compress
84+
:return: instantiated compressed tensor
85+
"""
86+
shape = list(tensor.shape)
87+
compressed, bitmask = sparse24_bitmask_compress(
88+
tensor.cpu(), sparsity_structure=sparsity_structure
89+
)
90+
return Sparse24BitMaskTensor(
91+
shape=shape,
92+
compressed=compressed,
93+
bitmask=bitmask,
94+
)
95+
96+
@staticmethod
97+
def from_compressed_data(
98+
shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
99+
) -> "Sparse24BitMaskTensor":
100+
"""
101+
:param shape: shape of the dense tensor (can be a list or a tensor)
102+
:param compressed: 2d tensor of non-zero values
103+
:param bitmask: 2d bitmask of non-zero values
104+
:return: instantiated Sparse24BitMaskTensor
105+
"""
106+
if isinstance(shape, Tensor):
107+
shape = shape.tolist()
108+
return Sparse24BitMaskTensor(
109+
shape=shape, compressed=compressed, bitmask=bitmask
110+
)
111+
112+
def decompress(self) -> Tensor:
113+
"""
114+
:return: reconstructed dense tensor
115+
"""
116+
return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
117+
118+
def curr_memory_size_bytes(self) -> int:
119+
"""
120+
:return: size in bytes required to store compressed tensor on disk
121+
"""
122+
123+
def sizeof_tensor(a: Tensor) -> int:
124+
return a.element_size() * a.nelement()
125+
126+
return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
127+
128+
def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
129+
"""
130+
:param name_prefix: name of original tensor to store compressed weight as
131+
:return: dict of compressed data for the stored weight
132+
"""
133+
if name_prefix.endswith(".weight"):
134+
name_prefix = name_prefix[: -len(".weight")]
135+
return {
136+
merge_names(name_prefix, "shape"): torch.tensor(
137+
self.shape, device=device
138+
).reshape(-1, 1),
139+
merge_names(name_prefix, "compressed"): self.compressed.to(device),
140+
merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
141+
}
142+
143+
def __repr__(self) -> str:
144+
return f"BitMaskTensor(shape={self.shape}, compressed=True)"
145+
146+
147+
def sparse24_bitmask_compress(
148+
tensor: Tensor,
149+
sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
150+
) -> Tuple[Tensor, Tensor, Tensor]:
151+
"""
152+
Compresses a dense tensor using bitmask compression
153+
154+
:param tensor: dense 2D tensor to compress
155+
:param sparsity_structure: structure of sparsity in the tensor, defaults
156+
to unstructured, can also be set to `2:4`
157+
:return: tuple of compressed data representing tensor
158+
"""
159+
assert len(tensor.shape) == 2, "Only 2D tensors are supported"
160+
assert (
161+
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
162+
), "Only 2:4 sparsity is supported"
163+
164+
bytemasks = get_24_bytemasks(tensor=tensor)
165+
166+
if tensor.dtype == FP8_DTYPE:
167+
# acces raw bytes of the tensor
168+
tensor_view = tensor.view(torch.int8)
169+
values = tensor_view[bytemasks]
170+
values = values.view(FP8_DTYPE)
171+
else:
172+
values = tensor[bytemasks]
173+
174+
num_rows, num_cols = tensor.shape
175+
compressed_values = values.reshape(num_rows, num_cols // 2)
176+
bitmasks_packed = pack_bitmasks(bytemasks)
177+
return compressed_values, bitmasks_packed
178+
179+
180+
def sparse24_bitmask_decompress(
181+
values: Tensor, bitmasks: Tensor, original_shape: torch.Size
182+
) -> Tensor:
183+
"""
184+
Reconstructs a dense tensor from a compressed one
185+
186+
:param values: 1d tensor of non-zero values
187+
:param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
188+
tensors original shape
189+
:param original_shape: shape of the dense tensor
190+
:return: decompressed dense tensor
191+
"""
192+
bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
193+
194+
decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
195+
decompressed_tensor = decompressed_tensor.to(values.device)
196+
values = values.flatten()
197+
if decompressed_tensor.dtype == FP8_DTYPE:
198+
decompressed_tensor[bytemasks_unpacked] = values
199+
decompressed_tensor = decompressed_tensor.cuda()
200+
else:
201+
decompressed_tensor[bytemasks_unpacked] = values
202+
return decompressed_tensor
203+
204+
205+
def get_24_bytemasks(tensor):
206+
"""
207+
Generate a 2:4 sparsity mask for the given tensor.
208+
209+
This function creates a mask where exactly 2 out of every 4 elements are
210+
preserved based on their magnitudes. The preserved elements are the ones
211+
with the highest absolute values in each group of 4 elements.
212+
213+
:param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
214+
The tensor can be of any shape but its total number of elements
215+
must be a multiple of 4.
216+
:return: A boolean tensor of the same shape as the input tensor, where `True`
217+
indicates the preserved elements and `False` indicates the pruned elements.
218+
:raises ValueError: If the total number of elements in the tensor is not a
219+
multiple of 4.
220+
"""
221+
original_dtype = tensor.dtype
222+
if tensor.dtype == FP8_DTYPE:
223+
tensor = tensor.view(torch.int8)
224+
original_shape = tensor.shape
225+
num_elements = tensor.numel()
226+
227+
if num_elements % 4 != 0:
228+
raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
229+
230+
reshaped_tensor = tensor.view(-1, 4)
231+
abs_tensor = reshaped_tensor.abs()
232+
topk_indices = abs_tensor.topk(2, dim=1).indices
233+
mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
234+
mask.scatter_(1, topk_indices, True)
235+
mask = mask.view(original_shape)
236+
tensor = tensor.view(original_dtype)
237+
238+
return mask

src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414

1515
from typing import Dict, List, Tuple, Union
1616

17-
import numpy
1817
import torch
1918
from compressed_tensors.compressors.base import BaseCompressor
2019
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
2120
from compressed_tensors.config import CompressionFormat
2221
from compressed_tensors.quantization import FP8_DTYPE
23-
from compressed_tensors.utils import merge_names
22+
from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
2423
from torch import Tensor
2524

2625

@@ -29,8 +28,6 @@
2928
"BitmaskTensor",
3029
"bitmask_compress",
3130
"bitmask_decompress",
32-
"pack_bitmasks",
33-
"unpack_bitmasks",
3431
]
3532

3633

@@ -164,37 +161,3 @@ def bitmask_decompress(
164161
decompressed_tensor[bytemasks_unpacked] = values
165162

166163
return decompressed_tensor
167-
168-
169-
def pack_bitmasks(bytemasks: Tensor) -> Tensor:
170-
"""
171-
Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
172-
compressed to R x ceil(C/8)
173-
:param bytemasks: mask tensor where each byte corresponds to a weight
174-
:return: mask tensor where each bit corresounds to a weight
175-
"""
176-
packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
177-
packed_bits_torch = torch.from_numpy(packed_bits_numpy)
178-
179-
return packed_bits_torch
180-
181-
182-
def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
183-
"""
184-
Converts a bitmask tensor back to a bytemask tensor for use during decompression
185-
186-
:param packed_bitmasks: mask tensor where each bit corresponds to a weight
187-
:param original_shape: dense shape to decompress to
188-
:return: boolean mask of weights in the original dense shape
189-
"""
190-
# Unpack the bits
191-
unpacked_bits = numpy.unpackbits(
192-
packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
193-
)
194-
195-
# Reshape to match the original shape
196-
unpacked_bitmasks_torch = torch.from_numpy(
197-
unpacked_bits.reshape(original_shape).astype(bool)
198-
)
199-
200-
return unpacked_bitmasks_torch

src/compressed_tensors/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
# flake8: noqa
1616
from .base import *
1717
from .dense import *
18+
from .sparse_24_bitmask import *
1819
from .sparse_bitmask import *

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
class CompressionFormat(Enum):
2727
dense = "dense"
2828
sparse_bitmask = "sparse-bitmask"
29+
sparse_24_bitmask = "sparse-24-bitmask"
2930
int_quantized = "int-quantized"
3031
float_quantized = "float-quantized"
3132
naive_quantized = "naive-quantized"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
from compressed_tensors.config import (
18+
CompressionFormat,
19+
SparsityCompressionConfig,
20+
SparsityStructure,
21+
)
22+
23+
24+
__all__ = ["Sparse24BitMaskConfig"]
25+
26+
27+
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
28+
class Sparse24BitMaskConfig(SparsityCompressionConfig):
29+
"""
30+
Configuration for storing a 24 sparse model using
31+
bytemask compression
32+
33+
:param global_sparsity: average sparsity of the entire model
34+
:param sparsity_structure: structure of the sparsity, should always be
35+
"2:4" for this compression format
36+
"""
37+
38+
format: str = CompressionFormat.sparse_24_bitmask.value
39+
global_sparsity: Optional[float] = 0.0
40+
sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value

0 commit comments

Comments
 (0)