
Commit 8e88f26

metascroy authored and liangel-02 committed
Add IntxUnpackedTensor (#2732)
* add intx unpacked tensor
* up
* up
* up
* up
* up
1 parent de3e0b6 commit 8e88f26

File tree

7 files changed: +488 −7 lines changed
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao.quantization import (
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+)
+
+
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+class TestIntxUnpackedTensor(TestCase):
+    def setUp(self):
+        self.config = IntxWeightOnlyConfig(
+            weight_dtype=torch.int4,
+            granularity=PerGroup(32),
+            version=2,
+        )
+
+    def test_embedding(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randint(low=0, high=128, size=(10,), device=device)
+        embedding = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
+        original = embedding(input)
+        quantize_(embedding, self.config)
+        quantized = embedding(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_linear(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, self.config)
+        quantized = linear(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
+        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))
+
+        # check that output from the sliced weight, before and after
+        # intx quantization, does not differ too much
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 20
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        device = "cpu"
+        l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
+        orig_value = param.data.qdata[0][0].item()
+
+        # dummy_l has random weights (shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # make sure param.data is updated
+        assert param.data.qdata[0][0] != orig_value
+
+    def test_to_dtype(self):
+        activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
+        activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
+        activations_fp16 = torch.randn(1, 128, dtype=torch.float16)
+
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+
+        linear.to(dtype=torch.float16)
+        linear(activations_fp16)
+
+        linear.to(dtype=torch.float32)
+        linear(activations_fp32)
+
+        linear.to(dtype=torch.bfloat16)
+        linear(activations_bf16)
+
+    def test_export(self):
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+        ep = torch.export.export(linear, (torch.randn(1, 128),))
+        assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code
+
+
+if __name__ == "__main__":
+    run_tests()

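The tests above gate on compute_error returning more than roughly 20. torchao's compute_error reports a signal-to-quantized-noise ratio (SQNR) in dB, so larger values mean the quantized module tracks the reference more closely. A minimal sketch of that idea under the usual SQNR definition (the helper name sqnr_db is illustrative, not part of torchao):

import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantized-noise ratio in dB; higher means the quantized
    # output is closer to the reference output.
    noise = reference - quantized
    return 20 * torch.log10(torch.linalg.norm(reference) / torch.linalg.norm(noise))

ref = torch.randn(1, 128)
approx = ref + 0.01 * torch.randn_like(ref)  # ~1% relative perturbation
print(sqnr_db(ref, approx))  # roughly 40 dB
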
torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -161,6 +162,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "IntxUnpackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",

torchao/quantization/quant_api.py

Lines changed: 46 additions & 7 deletions
@@ -75,6 +75,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -454,6 +455,10 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


+def _embedding_extra_repr(self):
+    return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):
@@ -1987,6 +1992,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
+    packing_format: PackingFormat = PackingFormat.UNPACKED_TO_INT8
+    version: int = 1

     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
@@ -2005,16 +2012,13 @@ def __post_init__(self):
         )


-@register_quantize_module_handler(IntxWeightOnlyConfig)
-def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
+def _intx_weight_only_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
+    packing_format = config.packing_format

     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2029,11 +2033,28 @@ def _intx_weight_only_transform(
     else:
         raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

+    block_size = (1, group_size)
+
+    if config.version == 2:
+        if config.packing_format == PackingFormat.UNPACKED_TO_INT8:
+            new_weight = IntxUnpackedTensor.from_hp(
+                weight,
+                block_size,
+                weight_dtype,
+                mapping_type=mapping_type,
+            )
+            if scale_dtype is not None and scale_dtype != weight.dtype:
+                new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=mapping_type,
-        block_size=(1, group_size),
+        block_size=block_size,
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
@@ -2043,7 +2064,25 @@ def _intx_weight_only_transform(
         zero_point_domain=ZeroPointDomain.INT,
         _layout=layout,
     )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return weight
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying intx weight only quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    elif isinstance(module, nn.Embedding):
+        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
+
     return module


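Taken together, the new packing_format and version fields let the existing IntxWeightOnlyConfig route weights through IntxUnpackedTensor.from_hp when version=2, while version=1 keeps the original to_affine_quantized_intx path. A minimal usage sketch, mirroring the test setUp above (layer sizes are illustrative):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# int4 weight-only quantization with 32-element groups, new version=2 path
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    version=2,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)
print(linear)  # extra_repr now reports the quantized weight type
out = linear(torch.randn(1, 128, dtype=torch.bfloat16))

Embedding modules go through the same handler, which is why _embedding_extra_repr is added alongside _linear_extra_repr.
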
torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
@@ -35,3 +35,8 @@ class PackingFormat(str, Enum):
     marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
     """
     MARLIN_SPARSE = "marlin_sparse"
+
+    """
+    Unpacked means the subbyte quantized data is stored as int8
+    """
+    UNPACKED_TO_INT8 = "unpacked_to_int8"

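As the docstring notes, UNPACKED_TO_INT8 keeps every sub-byte quantized value in its own int8 element rather than packing two 4-bit values into one byte, trading some memory for a representation that is straightforward to slice and dequantize. A rough illustration of the distinction, not the layout any torchao packed kernel actually uses:

import torch

# Four int4 values, each held in its own int8 element ("unpacked")
unpacked = torch.tensor([-3, 7, 0, -8], dtype=torch.int8)

# For contrast, a packed layout would store two 4-bit values per byte
nibbles = unpacked.to(torch.uint8) & 0xF        # two's-complement nibbles
packed = (nibbles[1::2] << 4) | nibbles[0::2]   # 2 bytes instead of 4
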
torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,11 +11,15 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .intx.intx_unpacked_tensor import (
+    IntxUnpackedTensor,
+)

 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "IntxUnpackedTensor",
 ]
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from .intx_unpacked_tensor import IntxUnpackedTensor
+
+__all__ = [
+    "IntxUnpackedTensor",
+]
