5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
8
- from typing import List , Optional , Tuple
8
+ from typing import List , Tuple
9
9
10
10
import torch
11
11
from torch .utils ._python_dispatch import return_and_correct_aliasing
12
12
13
13
from torchao .quantization .quant_primitives import (
14
- _DTYPE_TO_BIT_WIDTH ,
15
14
_DTYPE_TO_QVALUE_BOUNDS ,
16
15
MappingType ,
17
16
choose_qparams_affine ,
36
35
class IntxUnpackedTensor (TorchAOBaseTensor ):
37
36
"""
38
37
intx quantization with unpacked format. Subbyte quantized data is represented as int8.
38
+ The range of the quantized values are restricted to the quant_min and quant_max of the target_dtype, e.g.,
39
+ if target_dtype=torch.int4, qdata will be an int8 tensor with values in [-8, 7].
39
40
Quantization is represented in a decomposed way.
40
41
This format is intended for torch.export use cases.
41
42
42
43
Tensor Attributes:
43
- int_data : int data for quantization.
44
- dtype is int8
44
+ qdata : int data for quantization.
45
+ dtype is int8, but the range of the qdata is determined by target_dtype
45
46
Shape is the same as original Tensor: (n, k) for 2D tensor
46
47
scale: block scales for quantization
47
48
dtype is the same as the original Tensor dtype.
@@ -51,72 +52,60 @@ class IntxUnpackedTensor(TorchAOBaseTensor):
51
52
Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
52
53
53
54
Non-Tensor Attributes:
54
- bit_width: the bit width for quantization (can be 1 - 8 )
55
+ target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8 )
55
56
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
56
57
"""
57
58
58
- tensor_data_names = ["int_data " , "scale" , "zero_point" ]
59
- tensor_attribute_names = ["bit_width " , "block_size" ]
59
+ tensor_data_names = ["qdata " , "scale" , "zero_point" ]
60
+ tensor_attribute_names = ["target_dtype " , "block_size" ]
60
61
61
- def __new__ (cls , int_data , scale , zero_point , bit_width , block_size = None ):
62
+ def __new__ (cls , qdata , scale , zero_point , target_dtype , block_size = None ):
62
63
kwargs = {}
63
- kwargs ["device" ] = int_data .device
64
+ kwargs ["device" ] = qdata .device
64
65
kwargs ["dtype" ] = scale .dtype
65
66
kwargs ["requires_grad" ] = False
66
- shape = int_data .shape
67
+ shape = qdata .shape
67
68
return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
68
69
69
70
def __init__ (
70
71
self ,
71
- int_data ,
72
+ qdata ,
72
73
scale ,
73
74
zero_point ,
74
- bit_width ,
75
- block_size : Optional [ Tuple [int ]] = None ,
75
+ target_dtype ,
76
+ block_size : Tuple [int ],
76
77
):
77
- # Check plain data and infer block_size from shapes
78
- if block_size is None :
79
- assert scale .ndim == int_data .ndim
80
- assert zero_point .ndim == int_data .ndim
81
- block_size = []
82
- for i in range (int_data .ndim ):
83
- assert scale .shape [i ] == zero_point .shape [i ]
84
- n_blocks = scale .shape [i ]
85
- assert int_data .shape [i ] % n_blocks == 0
86
- block_size .append (int_data .shape [i ] // n_blocks )
87
- block_size = tuple (block_size )
88
- else :
89
- assert len (block_size ) == int_data .ndim
90
- n_blocks = []
91
- for i in range (len (block_size )):
92
- assert int_data .shape [i ] % block_size [i ] == 0
93
- n_blocks .append (int_data .shape [i ] // block_size [i ])
94
- scale = scale .reshape (* n_blocks )
95
- zero_point = zero_point .reshape (* n_blocks )
96
-
97
- assert block_size is not None
98
- assert isinstance (block_size , tuple )
99
- assert bit_width >= 1 and bit_width <= 8
100
-
101
- self .int_data = int_data
78
+ assert qdata .dtype == torch .int8 , (
79
+ f"qdata dtype must be int8, but got { qdata .dtype } "
80
+ )
81
+ assert scale .dtype in _FLOAT_TYPES , (
82
+ f"scale dtype must be one of { _FLOAT_TYPES } , but got { scale .dtype } "
83
+ )
84
+ assert zero_point .dtype in _FLOAT_TYPES or zero_point .dtype == torch .int8 , (
85
+ f"zero_point dtype must be { torch .int8 } or one of { _FLOAT_TYPES } , but got { zero_point .dtype } "
86
+ )
87
+
88
+ assert target_dtype in [
89
+ getattr (torch , f"int{ bit_width } " ) for bit_width in range (1 , 9 )
90
+ ]
91
+
92
+ assert len (block_size ) == qdata .ndim
93
+ n_blocks = []
94
+ for i in range (len (block_size )):
95
+ assert qdata .shape [i ] % block_size [i ] == 0
96
+ n_blocks .append (qdata .shape [i ] // block_size [i ])
97
+ scale = scale .reshape (* n_blocks )
98
+ zero_point = zero_point .reshape (* n_blocks )
99
+
100
+ self .qdata = qdata
102
101
self .scale = scale
103
102
self .zero_point = zero_point
104
103
105
- self .bit_width = bit_width
104
+ self .target_dtype = target_dtype
106
105
self .block_size = block_size
107
106
108
- def __repr__ (self ):
109
- repr_fields = (
110
- self .tensor_data_names
111
- + self .tensor_attribute_names
112
- + ["shape" , "device" , "dtype" , "require_grad" ]
113
- )
114
- inner_repr = [f"{ attr } ={ getattr (self , attr )} " for attr in repr_fields ]
115
- inner_repr = ", " .join (inner_repr )
116
- return f"{ self .__class__ .__name__ } ({ inner_repr } ))"
117
-
118
107
def _quantization_type (self ):
119
- return f"bit_width ={ self .bit_width } , block_size={ self .block_size } , shape={ self .shape } , dtype={ self .dtype } , device={ self .device } "
108
+ return f"target_dtype ={ self .target_dtype } , block_size={ self .block_size } , shape={ self .shape } , dtype={ self .dtype } , device={ self .device } "
120
109
121
110
def _has_float_zero_point (self ) -> bool :
122
111
return self .zero_point .dtype in _FLOAT_TYPES
@@ -126,40 +115,44 @@ def to(self, *args, **kwargs):
126
115
device = kwargs .pop ("device" )
127
116
dtype = kwargs .pop ("dtype" )
128
117
assert dtype in _FLOAT_TYPES
129
- return self . __class__ (
130
- self .int_data .to (device ),
118
+ return IntxUnpackedTensor (
119
+ self .qdata .to (device ),
131
120
self .scale .to (device = device , dtype = dtype ),
132
121
self .zero_point .to (device = device , dtype = dtype )
133
122
if self ._has_float_zero_point ()
134
123
else self .zero_point .to (device ),
135
- self .bit_width ,
124
+ self .target_dtype ,
136
125
self .block_size ,
137
126
)
138
127
139
128
@classmethod
140
129
def from_hp (
141
130
cls ,
142
- float_tensor : torch .Tensor ,
131
+ hp_tensor : torch .Tensor ,
143
132
block_size : Tuple [int ],
144
- dtype : torch .dtype ,
133
+ target_dtype : torch .dtype ,
145
134
* ,
146
135
mapping_type : MappingType = MappingType .SYMMETRIC ,
147
136
):
148
137
"""
149
138
Create an IntxUnpackedTensor from a high-precision tensor
150
139
"""
151
- qmin , qmax = _DTYPE_TO_QVALUE_BOUNDS [dtype ]
152
- bit_width = _DTYPE_TO_BIT_WIDTH [dtype ]
140
+ qmin , qmax = _DTYPE_TO_QVALUE_BOUNDS [target_dtype ]
153
141
scale , zero_point = choose_qparams_affine (
154
- float_tensor ,
142
+ hp_tensor ,
155
143
mapping_type ,
156
144
block_size ,
157
145
target_dtype = torch .int8 ,
158
146
quant_min = qmin ,
159
147
quant_max = qmax ,
160
148
)
161
- int_data = quantize_affine (
162
- float_tensor ,
149
+ if zero_point .dtype == torch .int32 :
150
+ int8_min , int8_max = _DTYPE_TO_QVALUE_BOUNDS [torch .int8 ]
151
+ assert zero_point .min ().item () >= int8_min
152
+ assert zero_point .max ().item () <= int8_max
153
+ zero_point = zero_point .to (torch .int8 )
154
+ qdata = quantize_affine (
155
+ hp_tensor ,
163
156
block_size ,
164
157
scale ,
165
158
zero_point ,
@@ -168,20 +161,17 @@ def from_hp(
168
161
quant_max = qmax ,
169
162
)
170
163
return IntxUnpackedTensor (
171
- int_data = int_data ,
164
+ qdata = qdata ,
172
165
scale = scale ,
173
166
zero_point = zero_point ,
174
- bit_width = bit_width ,
167
+ target_dtype = target_dtype ,
175
168
block_size = block_size ,
176
169
)
177
170
178
- def get_plain (self ):
179
- return self .int_data , self .scale , self .zero_point
180
-
181
171
def dequantize (self ):
182
- qmin , qmax = _DTYPE_TO_QVALUE_BOUNDS [getattr ( torch , f"int { self .bit_width } " ) ]
172
+ qmin , qmax = _DTYPE_TO_QVALUE_BOUNDS [self .target_dtype ]
183
173
return dequantize_affine (
184
- self .int_data ,
174
+ self .qdata ,
185
175
self .block_size ,
186
176
self .scale ,
187
177
self .zero_point ,
@@ -202,7 +192,10 @@ def _(func, types, args, kwargs):
202
192
args [1 ],
203
193
args [2 ] if len (args ) > 2 else None ,
204
194
)
205
- weight_tensor = weight_tensor .dequantize ()
195
+ if isinstance (input_tensor , IntxUnpackedTensor ):
196
+ input_tensor = input_tensor .dequantize ()
197
+ if isinstance (weight_tensor , IntxUnpackedTensor ):
198
+ weight_tensor = weight_tensor .dequantize ()
206
199
return torch .nn .functional .linear (input_tensor , weight_tensor , bias )
207
200
208
201
@@ -227,14 +220,14 @@ def _(func, types, args, kwargs):
227
220
# Otherwise the sliced tensor cannot be represented as a IntxUnpackedTensor
228
221
# For example, if block_size = 4, we might have:
229
222
#
230
- # int_data : i i i i | i i i i
223
+ # qdata : i i i i | i i i i
231
224
# scale: s s
232
225
#
233
- # If we set start = 2 and end = 8, then the int_data slice is:
226
+ # If we set start = 2 and end = 8, then the qdata slice is:
234
227
#
235
- # int_data_slice : i i (i i | i i i i)
228
+ # qdata_slice : i i (i i | i i i i)
236
229
#
237
- # But then the block_size for the first two int_data in the slice is 2
230
+ # But then the block_size for the first two qdata in the slice is 2
238
231
# and remaining blocks have size 4. This cannot be represented
239
232
# with the metadata we store in an IntxUnpackedTensor, which requires uniform blocking
240
233
@@ -248,15 +241,24 @@ def _(func, types, args, kwargs):
248
241
)
249
242
end_scale = end // self .block_size [dim ]
250
243
251
- int_data = aten .slice .Tensor (self .int_data , dim , start , end , step )
244
+ qdata = aten .slice .Tensor (self .qdata , dim , start , end , step )
252
245
scale = aten .slice .Tensor (self .scale , dim , start_scale , end_scale , step )
253
246
zero_point = aten .slice .Tensor (self .zero_point , dim , start_scale , end_scale , step )
254
247
255
- new = self .__class__ (
256
- int_data ,
248
+ new_block_size = []
249
+ for i in range (qdata .ndim ):
250
+ assert scale .shape [i ] == zero_point .shape [i ]
251
+ n_blocks = scale .shape [i ]
252
+ assert qdata .shape [i ] % n_blocks == 0
253
+ new_block_size .append (qdata .shape [i ] // n_blocks )
254
+ new_block_size = tuple (new_block_size )
255
+
256
+ new = IntxUnpackedTensor (
257
+ qdata ,
257
258
scale ,
258
259
zero_point ,
259
- self .bit_width ,
260
+ self .target_dtype ,
261
+ new_block_size ,
260
262
)
261
263
return return_and_correct_aliasing (func , args , kwargs , new )
262
264
0 commit comments