Commit 143fe91

up
1 parent bd29f1c commit 143fe91

4 files changed: 86 additions & 82 deletions

test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py

Lines changed: 5 additions & 5 deletions
@@ -73,10 +73,10 @@ def test_slice(self):
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
 
-        self.assertEqual(weight1.int_data, dummy.weight.int_data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
 
-        self.assertEqual(weight2.int_data, dummy.weight.int_data.narrow(1, 0, 128))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))
 
         # check for sliced weight, before and after float8 quantization
@@ -103,10 +103,10 @@ def test_slice_and_copy_(self):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data.int_data.data_ptr() == param_data.int_data.data_ptr()
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data.int_data[0][0].item()
+        orig_value = param.data.qdata[0][0].item()
 
         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
@@ -117,7 +117,7 @@ def test_slice_and_copy_(self):
         param_data.copy_(quantized)
 
         # making sure param.data is updated
-        assert param.data.int_data[0][0] != orig_value
+        assert param.data.qdata[0][0] != orig_value
 
     def test_to_dtype(self):
         activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -161,6 +162,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "IntxUnpackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class PackingFormat(str, Enum):
     marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
     """
     MARLIN_SPARSE = "marlin_sparse"
-
+
     """
     Unpacked means the subbyte quantized data is stored as int8
     """

torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py

Lines changed: 78 additions & 76 deletions
@@ -5,13 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import List, Optional, Tuple
+from typing import List, Tuple
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from torchao.quantization.quant_primitives import (
-    _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     choose_qparams_affine,
@@ -36,12 +35,14 @@
 class IntxUnpackedTensor(TorchAOBaseTensor):
     """
     intx quantization with unpacked format. Subbyte quantized data is represented as int8.
+    The range of the quantized values is restricted to the quant_min and quant_max of the target_dtype, e.g.,
+    if target_dtype=torch.int4, qdata will be an int8 tensor with values in [-8, 7].
     Quantization is represented in a decomposed way.
     This format is intended for torch.export use cases.
 
     Tensor Attributes:
-        int_data: int data for quantization.
-            dtype is int8
+        qdata: int data for quantization.
+            dtype is int8, but the range of the qdata is determined by target_dtype
             Shape is the same as original Tensor: (n, k) for 2D tensor
         scale: block scales for quantization
             dtype is the same as the original Tensor dtype.
@@ -51,72 +52,60 @@ class IntxUnpackedTensor(TorchAOBaseTensor):
             Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
 
     Non-Tensor Attributes:
-        bit_width: the bit width for quantization (can be 1 - 8)
+        target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
         block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
     """
 
-    tensor_data_names = ["int_data", "scale", "zero_point"]
-    tensor_attribute_names = ["bit_width", "block_size"]
+    tensor_data_names = ["qdata", "scale", "zero_point"]
+    tensor_attribute_names = ["target_dtype", "block_size"]
 
-    def __new__(cls, int_data, scale, zero_point, bit_width, block_size=None):
+    def __new__(cls, qdata, scale, zero_point, target_dtype, block_size=None):
         kwargs = {}
-        kwargs["device"] = int_data.device
+        kwargs["device"] = qdata.device
         kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
-        shape = int_data.shape
+        shape = qdata.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(
         self,
-        int_data,
+        qdata,
         scale,
         zero_point,
-        bit_width,
-        block_size: Optional[Tuple[int]] = None,
+        target_dtype,
+        block_size: Tuple[int],
     ):
-        # Check plain data and infer block_size from shapes
-        if block_size is None:
-            assert scale.ndim == int_data.ndim
-            assert zero_point.ndim == int_data.ndim
-            block_size = []
-            for i in range(int_data.ndim):
-                assert scale.shape[i] == zero_point.shape[i]
-                n_blocks = scale.shape[i]
-                assert int_data.shape[i] % n_blocks == 0
-                block_size.append(int_data.shape[i] // n_blocks)
-            block_size = tuple(block_size)
-        else:
-            assert len(block_size) == int_data.ndim
-            n_blocks = []
-            for i in range(len(block_size)):
-                assert int_data.shape[i] % block_size[i] == 0
-                n_blocks.append(int_data.shape[i] // block_size[i])
-            scale = scale.reshape(*n_blocks)
-            zero_point = zero_point.reshape(*n_blocks)
-
-        assert block_size is not None
-        assert isinstance(block_size, tuple)
-        assert bit_width >= 1 and bit_width <= 8
-
-        self.int_data = int_data
+        assert qdata.dtype == torch.int8, (
+            f"qdata dtype must be int8, but got {qdata.dtype}"
+        )
+        assert scale.dtype in _FLOAT_TYPES, (
+            f"scale dtype must be one of {_FLOAT_TYPES}, but got {scale.dtype}"
+        )
+        assert zero_point.dtype in _FLOAT_TYPES or zero_point.dtype == torch.int8, (
+            f"zero_point dtype must be {torch.int8} or one of {_FLOAT_TYPES}, but got {zero_point.dtype}"
+        )
+
+        assert target_dtype in [
+            getattr(torch, f"int{bit_width}") for bit_width in range(1, 9)
+        ]
+
+        assert len(block_size) == qdata.ndim
+        n_blocks = []
+        for i in range(len(block_size)):
+            assert qdata.shape[i] % block_size[i] == 0
+            n_blocks.append(qdata.shape[i] // block_size[i])
+        scale = scale.reshape(*n_blocks)
+        zero_point = zero_point.reshape(*n_blocks)
+
+        self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
 
-        self.bit_width = bit_width
+        self.target_dtype = target_dtype
         self.block_size = block_size
 
-    def __repr__(self):
-        repr_fields = (
-            self.tensor_data_names
-            + self.tensor_attribute_names
-            + ["shape", "device", "dtype", "require_grad"]
-        )
-        inner_repr = [f"{attr}={getattr(self, attr)}" for attr in repr_fields]
-        inner_repr = ", ".join(inner_repr)
-        return f"{self.__class__.__name__}({inner_repr}))"
-
     def _quantization_type(self):
-        return f"bit_width={self.bit_width}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}"
+        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}"
 
     def _has_float_zero_point(self) -> bool:
         return self.zero_point.dtype in _FLOAT_TYPES
@@ -126,40 +115,44 @@ def to(self, *args, **kwargs):
         device = kwargs.pop("device")
         dtype = kwargs.pop("dtype")
         assert dtype in _FLOAT_TYPES
-        return self.__class__(
-            self.int_data.to(device),
+        return IntxUnpackedTensor(
+            self.qdata.to(device),
             self.scale.to(device=device, dtype=dtype),
             self.zero_point.to(device=device, dtype=dtype)
             if self._has_float_zero_point()
             else self.zero_point.to(device),
-            self.bit_width,
+            self.target_dtype,
             self.block_size,
         )
 
     @classmethod
     def from_hp(
         cls,
-        float_tensor: torch.Tensor,
+        hp_tensor: torch.Tensor,
         block_size: Tuple[int],
-        dtype: torch.dtype,
+        target_dtype: torch.dtype,
         *,
         mapping_type: MappingType = MappingType.SYMMETRIC,
     ):
         """
         Create an IntxUnpackedTensor from a high-precision tensor
         """
-        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
        scale, zero_point = choose_qparams_affine(
-            float_tensor,
+            hp_tensor,
             mapping_type,
             block_size,
             target_dtype=torch.int8,
             quant_min=qmin,
             quant_max=qmax,
         )
-        int_data = quantize_affine(
-            float_tensor,
+        if zero_point.dtype == torch.int32:
+            int8_min, int8_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
+            assert zero_point.min().item() >= int8_min
+            assert zero_point.max().item() <= int8_max
+            zero_point = zero_point.to(torch.int8)
+        qdata = quantize_affine(
+            hp_tensor,
             block_size,
             scale,
             zero_point,
@@ -168,20 +161,17 @@ def from_hp(
             quant_max=qmax,
         )
         return IntxUnpackedTensor(
-            int_data=int_data,
+            qdata=qdata,
             scale=scale,
             zero_point=zero_point,
-            bit_width=bit_width,
+            target_dtype=target_dtype,
             block_size=block_size,
         )
 
-    def get_plain(self):
-        return self.int_data, self.scale, self.zero_point
-
     def dequantize(self):
-        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[getattr(torch, f"int{self.bit_width}")]
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.target_dtype]
         return dequantize_affine(
-            self.int_data,
+            self.qdata,
             self.block_size,
             self.scale,
             self.zero_point,
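For reference, a minimal usage sketch of the renamed API (not part of the commit; it only assumes the IntxUnpackedTensor export added to torchao/quantization/__init__.py above):

import torch
from torchao.quantization import IntxUnpackedTensor

# Quantize a bf16 weight to int4 with groupwise (1, 32) blocks.
w = torch.randn(256, 128, dtype=torch.bfloat16)
t = IntxUnpackedTensor.from_hp(w, block_size=(1, 32), target_dtype=torch.int4)

assert t.qdata.dtype == torch.int8                 # subbyte values stored unpacked in int8
assert t.qdata.min() >= -8 and t.qdata.max() <= 7  # restricted to the int4 range
assert t.scale.shape == (256, 4)                   # one scale/zero_point per (1, 32) block
assert t.dequantize().shape == w.shape             # decomposed dequantization round trip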
@@ -202,7 +192,10 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    weight_tensor = weight_tensor.dequantize()
+    if isinstance(input_tensor, IntxUnpackedTensor):
+        input_tensor = input_tensor.dequantize()
+    if isinstance(weight_tensor, IntxUnpackedTensor):
+        weight_tensor = weight_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
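A hedged sketch of what the new isinstance branches enable (illustrative, not from this commit): either operand of F.linear may now be an IntxUnpackedTensor, so a quantized activation can be paired with a quantized weight:

import torch
from torchao.quantization import IntxUnpackedTensor

x = IntxUnpackedTensor.from_hp(torch.randn(2, 128, dtype=torch.bfloat16), (1, 128), torch.int8)
w = IntxUnpackedTensor.from_hp(torch.randn(64, 128, dtype=torch.bfloat16), (1, 32), torch.int4)
y = torch.nn.functional.linear(x, w)  # both operands are dequantized before the matmul
assert y.shape == (2, 64)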

@@ -227,14 +220,14 @@ def _(func, types, args, kwargs):
     # Otherwise the sliced tensor cannot be represented as an IntxUnpackedTensor
     # For example, if block_size = 4, we might have:
     #
-    # int_data: i i i i | i i i i
+    # qdata: i i i i | i i i i
     # scale: s s
     #
-    # If we set start = 2 and end = 8, then the int_data slice is:
+    # If we set start = 2 and end = 8, then the qdata slice is:
     #
-    # int_data_slice: i i (i i | i i i i)
+    # qdata_slice: i i (i i | i i i i)
     #
-    # But then the block_size for the first two int_data in the slice is 2
+    # But then the block_size for the first two qdata in the slice is 2
     # and remaining blocks have size 4. This cannot be represented
     # with the metadata we store in an IntxUnpackedTensor, which requires uniform blocking
 
@@ -248,15 +241,24 @@ def _(func, types, args, kwargs):
     )
     end_scale = end // self.block_size[dim]
 
-    int_data = aten.slice.Tensor(self.int_data, dim, start, end, step)
+    qdata = aten.slice.Tensor(self.qdata, dim, start, end, step)
     scale = aten.slice.Tensor(self.scale, dim, start_scale, end_scale, step)
     zero_point = aten.slice.Tensor(self.zero_point, dim, start_scale, end_scale, step)
 
-    new = self.__class__(
-        int_data,
+    new_block_size = []
+    for i in range(qdata.ndim):
+        assert scale.shape[i] == zero_point.shape[i]
+        n_blocks = scale.shape[i]
+        assert qdata.shape[i] % n_blocks == 0
+        new_block_size.append(qdata.shape[i] // n_blocks)
+    new_block_size = tuple(new_block_size)
+
+    new = IntxUnpackedTensor(
+        qdata,
         scale,
         zero_point,
-        self.bit_width,
+        self.target_dtype,
+        new_block_size,
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
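A small sketch of the slicing rule described in the comments above (illustrative, not part of the commit): slices must land on block boundaries, and block_size is recomputed per dimension from the sliced qdata and scale shapes:

import torch
from torchao.quantization import IntxUnpackedTensor

t = IntxUnpackedTensor.from_hp(
    torch.randn(8, 8, dtype=torch.bfloat16), block_size=(1, 4), target_dtype=torch.int4
)
top = t.narrow(0, 0, 4)           # row blocks have size 1, so any row slice aligns
assert top.qdata.shape == (4, 8)
assert top.scale.shape == (4, 2)  # new_block_size is recomputed as (1, 4)
left = t.narrow(1, 0, 4)          # column slice must align with the 4-wide blocks
assert left.scale.shape == (8, 1)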
