Commit ba62d8e

Align Int4Tensor implementation details with the design of Float8Tensor

Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized tensor: its packed data has shape [N, K/2], and its scale/zero_point have shape [K/group_size, N].

This commit makes multiple fixes to Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops):
* Defined `tensor_data_names` and `tensor_attribute_names` so we can remove some of the implementations from TorchAOBaseTensor
* Migrated op implementations and tests from #2387

Note: this is just a refactor of Int4Tensor; there are no BC-related changes in this PR. The Int4Tensor path is exposed in version 2 of `Int4WeightOnlyConfig` (the default version is still 1, which uses the old AQT path).

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
1 parent c086ade commit ba62d8e
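For concreteness, here is a minimal sketch (not part of the commit) of the layout and the version-2 opt-in described in the summary. The shape checks follow from the stated [N, K/2] qdata and [K/group_size, N] scale/zero_point layout; opting in by assigning `config.VERSION = 2` is an assumption that mirrors the `config.VERSION == 2` check in quant_api.py below.

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_

N, K, group_size = 256, 128, 128
linear = torch.nn.Linear(K, N, dtype=torch.bfloat16, device="cuda")

config = Int4WeightOnlyConfig(group_size=group_size)
config.VERSION = 2  # assumed opt-in; default is still 1, the old AQT path
quantize_(linear, config)

w = linear.weight
# Layout from the summary: two int4 values are packed into each int8 byte.
assert w.qdata.shape == (N, K // 2)                # [N, K/2]
assert w.scale.shape == (K // group_size, N)       # [K/group_size, N]
assert w.zero_point.shape == (K // group_size, N)  # same shape as scale
```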

5 files changed: +430 -156 lines changed

test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 110 additions & 35 deletions
@@ -8,25 +8,21 @@

 import torch
 from torch.testing._internal.common_utils import (
-    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_90,
-)
+from torchao.testing.utils import TorchAOIntegrationTestCase
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
-class TestInt4Tensor(TestCase):
+class TestInt4Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
@@ -61,50 +57,46 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+        self.assertEqual(weight1.zero_point, dummy.weight.zero_point.narrow(1, 0, 64))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
+        self.assertEqual(weight2.zero_point, dummy.weight.zero_point.narrow(0, 0, 1))

         # check that the result for the sliced weight, before and after
         # int4 quantization, does not differ too much
         input = torch.randn(2, 256, dtype=dtype, device=device)
         res_ref = dummy1(input)
-        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 20

         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
-        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 15

-    def test_slice_and_copy_(self):
+    def test_slice_preserves_aliasing(self):
+        config = self.config
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data._data.data_ptr() == param_data._data.data_ptr()
+        # Making sure the aliasing is preserved in the sliced quantized Tensor
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data._data[0][0].item()
-
-        # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)

-        param_data.copy_(quantized)
-
-        # making sure param.data is updated
-        assert param.data._data[0][0] != orig_value
+    def test_slice_and_copy_similar_to_vllm(self):
+        self._test_slice_and_copy_similar_to_vllm(self.config)

+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -126,20 +118,103 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_to_device(self, sizes):
+        config = self.config
+        M, N, K = sizes
+        dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)

-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device=device)
+            linear(input_tensor)

-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_cat(self, sizes):
+        config = self.config
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+        linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+        dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy_linear1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (2 * N, K))
+        self.assertEqual(
+            dummy_linear1.weight.qdata,
+            cat_qweight1.qdata,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.scale,
+            cat_qweight1.scale,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.zero_point,
+            cat_qweight1.zero_point,
+        )
+
+        # making sure cat_qweight1 can be used for inference
+        dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+        dummy_linear1(input_cat1)
+
+        # align the scale and zero_point before concatenation
+        linear2.weight.scale = linear1.weight.scale
+        linear2.weight.zero_point = linear1.weight.zero_point
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (N, 2 * K))
+        ref_data = torch.cat(
+            [
+                linear1.weight.qdata,
+                linear2.weight.qdata,
+            ],
+            dim=1,
+        )
+        ref_scale = linear1.weight.scale
+        ref_zero_point = linear1.weight.zero_point
+        self.assertEqual(cat_qweight2.qdata, ref_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+        self.assertEqual(cat_qweight2.zero_point, ref_zero_point)
+
+    def test_moe_weight_reshape_ops(self):
+        self._test_moe_weight_reshape_ops(self.config)


+instantiate_parametrized_tests(TestInt4Tensor)

 if __name__ == "__main__":
     run_tests()
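As a reading aid for `test_slice` above, the shape arithmetic works out as follows, assuming a 256x256 weight with group_size=128 (consistent with the 2x256 test input; the exact dummy shapes are not visible in this hunk):

```python
# Assumed shapes for a (256, 256) bfloat16 weight, group_size=128:
#   weight.qdata:  (256, 128)  # [N, K/2]: two int4 values per byte
#   weight.scale:  (2, 256)    # [K/group_size, N] = [256/128, 256]
#
# weight1 = weight.narrow(0, 0, 64) slices N, so it narrows
#   qdata on dim 0 and scale/zero_point on dim 1 (both to 64).
# weight2 = weight.narrow(1, 0, 128) slices K, so it narrows
#   qdata on dim 1 to 128 // 2 = 64 packed bytes, and
#   scale/zero_point on dim 0 to 128 // group_size = 1 group.
```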

test/test_utils.py

Lines changed: 10 additions & 4 deletions
@@ -76,7 +76,7 @@ def __init__(self, qdata, attr, device=None):
                 self.qdata = qdata
                 self.attr = attr

-        l = torch.nn.Linear(1, 1)
+        l = torch.nn.Linear(2, 3)
         l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
         lp_tensor = l.weight
         # test __tensor_flatten__ and __tensor_unflatten__
@@ -107,18 +107,24 @@ def __init__(self, qdata, attr, device=None):
         # explicitly testing aten.alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
         lp_tensor = lp_tensor.clone()
+        # making qdata not contiguous
+        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
+        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
+        self.assertFalse(lp_tensor.qdata.is_contiguous())
         lp_tensor = lp_tensor.contiguous()
+        # making sure the contiguous call works
+        self.assertTrue(lp_tensor.qdata.is_contiguous())

         # copy_
-        another_tensor = torch.nn.Linear(1, 1).weight
+        another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
         another_lp_tensor = MyTensor(another_tensor, "attr")
         # initially tensor values are not the same
-        self.assertNotEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+        self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
         lp_tensor.copy_(another_lp_tensor)
         self.assertEqual(lp_tensor.attr, "attr")
         # after copy_, the tensor values should match
-        self.assertEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+        self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])


 if __name__ == "__main__":
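The `MyTensor` fixture above exercises the declarative pattern from the commit summary: a subclass lists its tensor components and plain attributes, and TorchAOBaseTensor supplies flatten/unflatten, copy_, contiguous, and similar defaults. A rough sketch of such a subclass, with the `__new__` body and import path assumed rather than copied from this diff:

```python
import torch
from torchao.utils import TorchAOBaseTensor  # import path assumed

class MyTensor(TorchAOBaseTensor):
    # Declaring these lets the base class derive __tensor_flatten__,
    # __tensor_unflatten__, detach/clone/alias, copy_, contiguous, etc.
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr", "device"]

    def __new__(cls, qdata, attr, device=None):
        # Wrapper subclass: the outer tensor mirrors qdata's shape/dtype.
        device = device if device is not None else qdata.device
        return torch.Tensor._make_wrapper_subclass(
            cls, qdata.shape, dtype=qdata.dtype, device=device
        )

    def __init__(self, qdata, attr, device=None):
        self.qdata = qdata
        self.attr = attr
```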

torchao/quantization/quant_api.py

Lines changed: 3 additions & 2 deletions
@@ -1160,6 +1160,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

     if config.VERSION == 2:
+        block_size = list(block_size)
         if packing_format == PackingFormat.PRESHUFFLED:
             new_weight = Int4PreshuffledTensor.from_float(
                 weight,
@@ -1168,7 +1169,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
             )
             return new_weight
         elif packing_format == PackingFormat.PLAIN:
-            new_weight = Int4Tensor.from_float(
+            new_weight = Int4Tensor.from_hp(
                 weight,
                 block_size,
             )
@@ -2212,7 +2213,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
         )
     else:
-        weight = Int4Tensor.from_float(
+        weight = Int4Tensor.from_hp(
             module.weight,
             config.block_size,
         )
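With the rename, callers construct the plain-packed tensor from a high-precision weight via `from_hp`. A hedged sketch of a direct call follows; the import path is an assumption, while the (weight, block_size) signature and the list-valued block_size match the diff above:

```python
import torch
from torchao.quantization import Int4Tensor  # import path assumed

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
group_size = 128
# quant_api.py builds block_size as 1 on every dim except the last:
block_size = [1] * (weight.ndim - 1) + [group_size]  # [1, 128]
qweight = Int4Tensor.from_hp(weight, block_size)
```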
