Commit c9eb459

add intx unpacked tensor

1 parent e6b38bb commit c9eb459

6 files changed: +505 -7 lines changed
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
class TestIntxUnpackedTensor(TestCase):
    def setUp(self):
        self.config = IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=PerGroup(32),
            VERSION=2,
        )

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cpu"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        error = compute_error(original, quantized)
        self.assertTrue(error > 20)

    def test_slice(self):
        dtype = torch.bfloat16
        device = "cpu"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)

        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )

        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)

        # slicing the quantized weight should slice int_data and scale accordingly
        self.assertEqual(weight1.int_data, dummy.weight.int_data.narrow(0, 0, 64))
        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))

        self.assertEqual(weight2.int_data, dummy.weight.int_data.narrow(1, 0, 128))
        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))

        # check that outputs computed with the sliced quantized weight
        # do not differ too much from the unquantized reference
        input = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
        assert compute_error(res, res_ref) > 20

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
        assert compute_error(res, res_ref) > 15

    def test_slice_and_copy_(self):
        device = "cpu"
        l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
        )
        quantize_(l, self.config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        assert param.data.int_data.data_ptr() == param_data.int_data.data_ptr()
        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
        orig_value = param.data.int_data[0][0].item()

        # dummy_l has random weights (shouldn't be 0)
        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
        quantize_(dummy_l, self.config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

        param_data.copy_(quantized)

        # making sure param.data is updated
        assert param.data.int_data[0][0] != orig_value

    def test_to_dtype(self):
        activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
        activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
        activations_fp16 = torch.randn(1, 128, dtype=torch.float16)

        linear = torch.nn.Linear(128, 256)
        quantize_(linear, self.config)

        linear.to(dtype=torch.float16)
        linear(activations_fp16)

        linear.to(dtype=torch.float32)
        linear(activations_fp32)

        linear.to(dtype=torch.bfloat16)
        linear(activations_bf16)

    def test_export(self):
        linear = torch.nn.Linear(128, 256)
        quantize_(linear, self.config)
        ep = torch.export.export(linear, (torch.randn(1, 128),))
        assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code


if __name__ == "__main__":
    run_tests()
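The tests above exercise the new VERSION=2 code path end to end. As a minimal usage sketch distilled from the test setup (the config arguments are taken from setUp above; the resulting weight subclass, IntxUnpackedTensor, is what this commit introduces):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# Opt into the new unpacked intx path via VERSION=2 (the default stays 1).
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    VERSION=2,
)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)  # replaces linear.weight with an IntxUnpackedTensor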

torchao/quantization/quant_api.py

Lines changed: 36 additions & 7 deletions
@@ -75,6 +75,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -1987,6 +1988,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
+    packing_format: PackingFormat = PackingFormat.UNPACKED
+    VERSION: int = 1

     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
@@ -2005,16 +2008,13 @@ def __post_init__(self):
         )


-@register_quantize_module_handler(IntxWeightOnlyConfig)
-def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
+def _intx_weight_only_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
+    packing_format = config.packing_format

     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2029,11 +2029,28 @@ def _intx_weight_only_transform(
     else:
         raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

+    block_size = (1, group_size)
+
+    if config.VERSION == 2:
+        if config.packing_format == PackingFormat.UNPACKED:
+            new_weight = IntxUnpackedTensor.from_float(
+                weight,
+                block_size,
+                weight_dtype,
+                mapping_type=mapping_type,
+            )
+            if scale_dtype is not None and scale_dtype != weight.dtype:
+                new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=mapping_type,
-        block_size=(1, group_size),
+        block_size=block_size,
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
@@ -2043,7 +2060,19 @@ def _intx_weight_only_transform(
         zero_point_domain=ZeroPointDomain.INT,
         _layout=layout,
     )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying intx weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
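In the VERSION=2 branch above, block_size is (1, group_size), so scales (and zero points) are stored per output row with one entry per group of input features. A quick shape check consistent with the slicing test in this commit (a sketch with example shapes, not torchao code):

import torch

out_features, in_features, group_size = 256, 128, 32
block_size = (1, group_size)
n_groups = in_features // group_size  # 128 // 32 == 4

# Unpacked int4: int_data keeps the full weight shape in int8,
# with one scale per (row, group).
int_data = torch.zeros(out_features, in_features, dtype=torch.int8)
scale = torch.ones(out_features, n_groups, dtype=torch.bfloat16)
assert scale.shape == (256, 4)  # matches weight2.scale.narrow(1, 0, 4) in test_slice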

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
@@ -35,3 +35,8 @@ class PackingFormat(str, Enum):
     marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
     """
     MARLIN_SPARSE = "marlin_sparse"
+
+    """
+    Unpacked means the subbyte quantized data is stored as int8
+    """
+    UNPACKED = "unpacked"
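To make the "unpacked" wording concrete: each sub-byte value occupies a full int8 element instead of being bit-packed. A rough illustration of the difference (plain PyTorch, not the torchao layout code):

import torch

# Unpacked: one int4 value (range [-8, 7]) per int8 element.
unpacked = torch.randint(-8, 8, (4, 8), dtype=torch.int8)

# A bit-packed layout would squeeze two 4-bit values into each byte instead,
# halving storage at the cost of pack/unpack work in the kernels.
nibbles = unpacked.to(torch.int16) & 0xF  # two's-complement low nibble of each value
packed = (nibbles[:, 0::2] | (nibbles[:, 1::2] << 4)).to(torch.uint8)
assert packed.shape == (4, 4)  # half as many storage bytes per row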

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,11 +11,15 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .intx.intx_unpacked_tensor import (
+    IntxUnpackedTensor,
+)

 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "IntxUnpackedTensor",
 ]
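With this re-export in place, the new tensor subclass should be importable from the workflows package; a quick smoke check, assuming the package layout shown in this commit:

from torchao.quantization.quantize_.workflows import IntxUnpackedTensor
print(IntxUnpackedTensor)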
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .intx_unpacked_tensor import IntxUnpackedTensor

__all__ = [
    "IntxUnpackedTensor",
]
