
Commit b40fd97

Int4 sparse marlin tensor (#2771)
* added marlin sparse to packing format, initial commit
* deleting unnecessary functions
* packing
* linear
* add call to from_hp
* unit test
* fix test_linear
* formatting
* remove comments
* update VERSION to version
* fix module path unit test
* adding sizes to linear unit test
* move pre_process and from_plain to from_hp
* compile test
1 parent 69e71d9 commit b40fd97

File tree

6 files changed: +341 -0 lines changed
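
Before the per-file diffs, here is a minimal usage sketch of the new marlin_sparse packing format, mirroring the unit test added in this commit. It assumes a CUDA device with compute capability 8.0 or higher and PyTorch 2.8+ (the same requirements encoded by the test's skip conditions); apply_fake_sparsity is used here, as in the test, only to give the weight a 2:4-prunable pattern.

import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Sketch only: requires CUDA (compute capability >= 8.0) and PyTorch 2.8+.
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="marlin_sparse",  # the packing format added by this commit
    version=2,
)

linear = torch.nn.Linear(128, 256, dtype=torch.float16, device="cuda")
apply_fake_sparsity(linear)  # prune the weight to a 2:4 sparse pattern
quantize_(linear, config)    # linear.weight becomes an Int4MarlinSparseTensor
out = linear(torch.randn(32, 128, dtype=torch.float16, device="cuda"))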
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
)

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="marlin_sparse",
    version=2,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4MarlinSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.float16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4MarlinSparseTensor'>",
            )


instantiate_parametrized_tests(TestInt4MarlinSparseTensor)


if __name__ == "__main__":
    run_tests()

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,7 @@
 )
 from .quantize_.workflows import (
     Float8Tensor,
+    Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
 )
@@ -159,6 +160,7 @@
     # tensor subclasses
     "Int4Tensor",
     "Int4PreshuffledTensor",
+    "Int4MarlinSparseTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",

torchao/quantization/quant_api.py

Lines changed: 7 additions & 0 deletions
@@ -72,6 +72,7 @@
 )
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
+    Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -1068,6 +1069,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.MARLIN_SPARSE:
+        new_weight = Int4MarlinSparseTensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")
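
Under the hood, the new branch simply hands the high-precision weight to the new tensor subclass. Below is a hedged sketch of the equivalent direct call; the weight shape and the [1, 128] block_size are illustrative choices that satisfy the group-size and divisibility checks in from_hp (shown in the tensor file further down), and a CUDA device with compute capability 8.0+ is assumed.

import torch

from torchao.quantization import Int4MarlinSparseTensor

# Illustrative only: weight is (out_features, in_features); from_hp injects 2:4
# sparsity, quantizes symmetrically to int4, and packs to the marlin 2:4 format.
weight = torch.randn(256, 128, dtype=torch.float16, device="cuda")
qweight = Int4MarlinSparseTensor.from_hp(weight, [1, 128])  # 128-wide quantization groups
print(qweight.block_size, qweight.num_bits)  # 128, 4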

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
@@ -30,3 +30,8 @@ class PackingFormat(str, Enum):
     preshuffled is referring to the preshuffled format used by fbgemm kernels
     """
     PRESHUFFLED = "preshuffled"
+
+    """
+    marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
+    """
+    MARLIN_SPARSE = "marlin_sparse"

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,9 @@
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
+from .int4.int4_marlin_sparse_tensor import (
+    Int4MarlinSparseTensor,
+)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -12,6 +15,7 @@
 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
+    "Int4MarlinSparseTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
 ]
torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)
from torchao.utils import TorchAOBaseTensor

__all__ = [
    "Int4MarlinSparseTensor",
]

aten = torch.ops.aten


class Int4MarlinSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point", "meta"]
    tensor_attribute_names = ["block_size", "num_bits", "shape"]

    def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        kwargs = {}
        kwargs["device"] = qdata.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.meta = meta
        self.block_size = block_size
        self.num_bits = num_bits

    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.marlin import (
            const,
            inject_24,  # avoid circular import
            pack_to_marlin_24,
        )

        """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
        - 1st: the input tensor is transposed, since the linear layer keeps the weights in a transposed format
        - 2nd: the tensor is injected with 2:4 sparsity
        - 3rd: it is transposed again, because the quantization process computes the scales for dim=-1
        """

        w_t = w.t()
        w_24, _ = inject_24(w_t, *w_t.shape)
        preprocessed_w = w_24.t()

        assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], (
            f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}"
        )

        quant_min = 0
        quant_max = 15
        target_dtype = torch.int32

        scale, zero_point = choose_qparams_affine(
            input=preprocessed_w,
            mapping_type=MappingType.SYMMETRIC,
            block_size=block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=1e-6,
        )

        wq = quantize_affine(
            input=preprocessed_w,
            block_size=block_size,
            scale=scale,
            zero_point=zero_point,
            output_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
        )

        scale = scale.to(w.dtype)
        zero_point = zero_point.to(w.dtype)

        # Linear layers are (in_features, out_features) but the qdata that is reaching this point
        # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
        q_w_24 = wq.t()
        # addressing the case when scale has dimension 1, happens when
        # weight_shape[-1] == group_size == 128
        if scale.ndim == 1:
            scale = scale.reshape(scale.shape[0], -1)

        scale_t = scale.t()

        if not torch.cuda.get_device_capability()[0] >= 8:
            raise ValueError(
                f"Cannot use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel."
            )

        if q_w_24.dtype != torch.int32:
            raise ValueError("Only `torch.int32` weights are supported.")

        in_features, out_features = q_w_24.shape
        if in_features % 128 != 0 or out_features % 256 != 0:
            raise ValueError(
                "`in_features` must be divisible by 128 and `out_features` by 256."
            )

        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
        # will require a bit more work to get our current quantization flow to work with it.
        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
        num_bits = 4 if torch.max(q_w_24) < 16 else -1
        if num_bits not in [4]:
            raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.")

        group_size = in_features // scale_t.shape[0]
        if group_size == 0:
            group_size = in_features
        assert group_size <= in_features, (
            "Group size must be less than or equal to in_features."
        )

        if group_size not in const.SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
            )

        # Compress quantized weight to marlin 2:4 format
        marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(
            q_w_24, scale_t, num_bits, group_size
        )

        return cls(
            qdata=marlin_24_q_w_comp,
            scale=marlin_24_s,
            zero_point=zero_point,
            meta=meta,
            block_size=group_size,
            shape=q_w_24.shape,
            num_bits=num_bits,
        )


implements = Int4MarlinSparseTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    from torchao.ops import marlin_24_gemm
    from torchao.sparsity.marlin import marlin_24_workspace

    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
    assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
    assert weight_tensor.zero_point.is_contiguous(), (
        "Expected zero_point to be contiguous"
    )

    sparse_w_int4 = weight_tensor.qdata
    scale = weight_tensor.scale
    meta = weight_tensor.meta
    original_shape = weight_tensor.shape
    num_bits = weight_tensor.num_bits

    # Folds batch dimension into the first dimension
    input_2d = input_tensor.view(-1, input_tensor.shape[-1])

    size_m = input_2d.shape[0]
    size_n = scale.shape[1]
    size_k = input_2d.shape[1]
    workspace_24 = marlin_24_workspace(original_shape[1])

    out = marlin_24_gemm(
        input_2d,
        sparse_w_int4,
        meta,
        scale,
        workspace_24,
        num_bits,
        size_m,
        size_n,
        size_k,
    )

    # Unfold the batch dimension
    out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],))

    if bias is not None:
        out += bias.to(out.dtype)
    return out


Int4MarlinSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4MarlinSparseTensor])
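
As a usage note, the add_safe_globals registration above is what allows checkpoints containing this subclass to be loaded with weights_only=True. A minimal round-trip sketch, assuming the same marlin_sparse config used in the unit test:

import tempfile

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

config = Int4WeightOnlyConfig(group_size=128, packing_format="marlin_sparse", version=2)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16).cuda()
quantize_(linear, config)

with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    # Succeeds under weights_only because Int4MarlinSparseTensor is a registered safe global.
    state_dict = torch.load(f, weights_only=True)
    assert isinstance(state_dict["weight"], torch.Tensor)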
