Commit cffba61
Add from_int4_tensor in Int4PreshuffledTensor (#2978)

Summary: Added a classmethod `from_int4_tensor` to convert a plain `Int4Tensor` to an `Int4PreshuffledTensor`. This prepares for supporting Int4PreshuffledTensor in vllm, which requires the weight tensor to be sliced before inference (see https://github.com/pytorch/ao/blob/186aeb01664687d14108ada420c475cc783e1643/torchao/testing/utils.py#L429 for details). Since Int4PreshuffledTensor can't easily be sliced while also preserving aliasing, we plan to slice the plain int4 tensor instead and convert it to Int4PreshuffledTensor at a later stage. The next PR will add a top-level API in prototype to convert from Int4Tensor to Int4PreshuffledTensor.

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py -k test_from_int4_tensor

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent f1e118b commit cffba61
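For context, a minimal sketch of the slice-then-convert flow this classmethod enables. Assumptions: a CUDA device with fbgemm-gpu available, and that slicing a plain Int4Tensor behaves as exercised in the linked vllm test; the layer sizes and slice bounds are purely illustrative.

import torch

from torchao.quantization import (
    Int4PreshuffledTensor,
    Int4WeightOnlyConfig,
    quantize_,
)

# 1. quantize with the plain int4 packing format, which supports slicing
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain"))

# 2. slice the plain Int4Tensor the way vllm's weight loading requires
#    (hypothetical bounds, for illustration only)
sliced = linear.weight[0:128]

# 3. convert the plain tensor to the preshuffled format before inference
preshuffled = Int4PreshuffledTensor.from_int4_tensor(sliced)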

File tree: 2 files changed, +64 −0 lines

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 30 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import tempfile
 import unittest

@@ -17,6 +18,7 @@

 from torchao.quantization import (
     Float8DynamicActivationInt4WeightConfig,
+    Int4PreshuffledTensor,
     Int4WeightOnlyConfig,
     quantize_,
 )
@@ -82,6 +84,34 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

+    def test_from_int4_tensor(self):
+        """Test that constructing Int4PreshuffledTensor from Int4Tensor
+        is the same as quantizing the original weight to Int4PreshuffledTensor
+        """
+        int4_config = Int4WeightOnlyConfig(
+            group_size=128,
+            int4_packing_format="plain",
+        )
+        int4_preshuffled_config = Int4WeightOnlyConfig(
+            group_size=128,
+            int4_packing_format="preshuffled",
+        )
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear2 = copy.deepcopy(linear1)
+
+        quantize_(linear1, int4_config)
+        quantize_(linear2, int4_preshuffled_config)
+
+        # now convert linear1.weight to Int4PreshuffledTensor
+        w1_preshuffled = Int4PreshuffledTensor.from_int4_tensor(linear1.weight)
+        linear1.weight = torch.nn.Parameter(w1_preshuffled, requires_grad=False)
+
+        example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
+
+        output1 = linear1(*example_inputs)
+        output2 = linear2(*example_inputs)
+        self.assertEqual(output1, output2)
+
     @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
     def test_to_device(self, config):
         for device in self.GPU_DEVICES:
torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 34 additions & 0 deletions
@@ -10,6 +10,7 @@

 import torch

+from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from torchao.utils import (
     TorchAOBaseTensor,
 )
@@ -27,6 +28,7 @@
 ):
     quantize_int4_preshuffle = None
     quantize_fp8_row = None
+    pack_int4 = None
 else:
     from fbgemm_gpu.experimental.gen_ai.quantize import (
         quantize_fp8_row,
@@ -185,6 +187,38 @@ def from_hp(
             row_scale=row_scale,
         )

+    @classmethod
+    def from_int4_tensor(
+        cls,
+        tensor: Int4Tensor,
+    ):
+        assert isinstance(tensor, Int4Tensor), (
+            f"Only conversion from Int4Tensor is supported, got: {tensor}"
+        )
+        # currently Int4Tensor only supports weight-only quantization; we can extend it to fp8-int4 a bit later
+        qdata = tensor.qdata
+        group_scale = tensor.scale
+        group_zero = tensor.zero_point
+        block_size = tensor.block_size
+        original_shape = tensor.shape
+        row_scale = None
+
+        # Set scales to activation type.
+        group_scale = group_scale.to(torch.bfloat16)
+        group_zero = group_zero.to(torch.bfloat16)
+        # pack weights and scales into efficient preshuffled format
+        preshuffled_qdata, group_scale = torch.ops.fbgemm.preshuffle_i4(
+            qdata, group_scale
+        )
+        return Int4PreshuffledTensor(
+            qdata=preshuffled_qdata,
+            group_scale=group_scale,
+            block_size=block_size,
+            shape=original_shape,
+            group_zero=group_zero,
+            row_scale=row_scale,
+        )
+

 implements = Int4PreshuffledTensor.implements
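Note on the implementation: the conversion casts group_scale and group_zero to bfloat16 (the activation dtype) and repacks qdata with torch.ops.fbgemm.preshuffle_i4, which also rewrites group_scale into the layout the preshuffled gemm kernels expect. row_scale stays None because, as the in-code comment notes, Int4Tensor currently only carries weight-only quantization state; fp8-int4 support is left for a follow-up.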