
Commit e68c9fa

add intx unpacked tensor
1 parent 6cfa477 commit e68c9fa

6 files changed: +505 −7 lines changed

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
class TestIntxUnpackedTensor(TestCase):
    def setUp(self):
        self.config = IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=PerGroup(32),
            VERSION=2,
        )

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cpu"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        error = compute_error(original, quantized)
        self.assertTrue(error > 20)

    def test_slice(self):
        dtype = torch.bfloat16
        device = "cpu"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)

        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )

        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)

        self.assertEqual(weight1.int_data, dummy.weight.int_data.narrow(0, 0, 64))
        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))

        # with PerGroup(32), slicing 128 of the 256 input columns covers 128 / 32 = 4 scale groups
        self.assertEqual(weight2.int_data, dummy.weight.int_data.narrow(1, 0, 128))
        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))

        # check that results from the sliced weight, before and after
        # intx quantization, do not differ too much
        input = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
        assert compute_error(res, res_ref) > 20

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
        assert compute_error(res, res_ref) > 15

    def test_slice_and_copy_(self):
        device = "cpu"
        l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
        )
        quantize_(l, self.config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        assert param.data.int_data.data_ptr() == param_data.int_data.data_ptr()
        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
        orig_value = param.data.int_data[0][0].item()

        # dummy_l has random weights (shouldn't be 0)
        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
        quantize_(dummy_l, self.config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

        param_data.copy_(quantized)

        # make sure param.data is updated
        assert param.data.int_data[0][0] != orig_value

    def test_to_dtype(self):
        activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
        activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
        activations_fp16 = torch.randn(1, 128, dtype=torch.float16)

        linear = torch.nn.Linear(128, 256)
        quantize_(linear, self.config)

        linear.to(dtype=torch.float16)
        linear(activations_fp16)

        linear.to(dtype=torch.float32)
        linear(activations_fp32)

        linear.to(dtype=torch.bfloat16)
        linear(activations_bf16)

    def test_export(self):
        linear = torch.nn.Linear(128, 256)
        quantize_(linear, self.config)
        ep = torch.export.export(linear, (torch.randn(1, 128),))
        assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code


if __name__ == "__main__":
    run_tests()
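For reference, the components that the slicing tests assert on can be inspected directly on the quantized weight. A small sketch assuming the same 256x256 / PerGroup(32) setup as above; the attribute names (int_data, scale, zero_point) come from the test and the expected shapes are inferred from its assertions, not stated elsewhere in this commit:

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

m = torch.nn.Linear(256, 256, bias=False, dtype=torch.bfloat16)
quantize_(m, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32), VERSION=2))
w = m.weight
# int_data keeps one int8 element per weight value (unpacked), so it should keep the (256, 256) shape;
# scale / zero_point carry one entry per group: 256 / 32 = 8 groups along the input dimension.
print(w.int_data.shape, w.scale.shape, w.zero_point.shape)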

torchao/quantization/quant_api.py

Lines changed: 36 additions & 7 deletions
@@ -76,6 +76,7 @@
     Float8Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -2060,6 +2061,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
+    packing_format: PackingFormat = PackingFormat.UNPACKED
+    VERSION: int = 1

     def __post_init__(self):
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
@@ -2078,16 +2081,13 @@ def __post_init__(self):
         )


-@register_quantize_module_handler(IntxWeightOnlyConfig)
-def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
+def _intx_weight_only_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
+    packing_format = config.packing_format

     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2102,11 +2102,28 @@ def _intx_weight_only_transform(
     else:
         raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

+    block_size = (1, group_size)
+
+    if config.VERSION == 2:
+        if config.packing_format == PackingFormat.UNPACKED:
+            new_weight = IntxUnpackedTensor.from_float(
+                weight,
+                block_size,
+                weight_dtype,
+                mapping_type=mapping_type,
+            )
+            if scale_dtype is not None and scale_dtype != weight.dtype:
+                new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=mapping_type,
-        block_size=(1, group_size),
+        block_size=block_size,
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
@@ -2116,7 +2133,19 @@ def _intx_weight_only_transform(
         zero_point_domain=ZeroPointDomain.INT,
         _layout=layout,
     )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return weight
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying intx weight only quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
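A minimal end-to-end sketch of the path this change adds, mirroring the test file above (shapes are illustrative; assumes a torchao build that contains this commit):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# VERSION=2 selects the new tensor-subclass path; packing_format defaults to
# PackingFormat.UNPACKED, which routes through IntxUnpackedTensor.from_float.
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    VERSION=2,
)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)  # linear.weight now wraps an IntxUnpackedTensor
out = linear(torch.randn(1, 128, dtype=torch.bfloat16))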

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
@@ -30,3 +30,8 @@ class PackingFormat(str, Enum):
     preshuffled is referring to the preshuffled format used by fbgemm kernels
     """
     PRESHUFFLED = "preshuffled"
+
+    """
+    Unpacked means the subbyte quantized data is stored as int8
+    """
+    UNPACKED = "unpacked"
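To make the docstring concrete, here is a small illustration of unpacked versus packed storage of int4 values; this is only a sketch of the byte layout, not the IntxUnpackedTensor subclass itself:

import torch

# "unpacked": each int4 value in [-8, 7] occupies its own int8 element
unpacked = torch.tensor([-8, -3, 0, 7], dtype=torch.int8)  # 4 values -> 4 bytes

# a packed layout would instead squeeze two int4 values into one byte
packed = torch.tensor(
    [(-8 & 0xF) | ((-3 & 0xF) << 4), (0 & 0xF) | ((7 & 0xF) << 4)],
    dtype=torch.uint8,
)  # 4 values -> 2 bytes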

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,10 +8,14 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .intx.intx_unpacked_tensor import (
+    IntxUnpackedTensor,
+)

 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "IntxUnpackedTensor",
 ]
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .intx_unpacked_tensor import IntxUnpackedTensor

__all__ = [
    "IntxUnpackedTensor",
]
