9
9
from typing import List , Optional
10
10
11
11
import torch
12
- from torch .utils ._python_dispatch import return_and_correct_aliasing
13
12
14
13
from torchao .utils import (
15
14
TORCH_VERSION_AT_LEAST_2_5 ,
16
15
TorchAOBaseTensor ,
17
- fill_defaults ,
18
16
)
19
17
20
18
__all__ = [
@@ -42,12 +40,12 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
42
40
int4 quantization with preshuffled packing format (for all granularities)
43
41
44
42
Tensor Attributes:
45
- _data : preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
43
+ qdata : preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
46
44
preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP
47
45
for bf16 activation:
48
- group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
46
+ group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N ) for 3D Tensor, where B is batch size,
49
47
dtype is the same as the original Tensor dtype
50
- group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
48
+ group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N ) for 3D Tensor, where B is batch size,
51
49
dtype is the same as the original Tensor dtype
52
50
for float8 activation:
53
51
group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
@@ -57,9 +55,6 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
57
55
58
56
Non-Tensor Attributes:
59
57
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
60
- shape_multiplier: is the multipler from _data to the real weight, since
61
- we pack the weight for int4, for example, when we pack the last dimension for
62
- a 2D tensor, the shape_multiplier will be [1, 2]
63
58
shape: shape of the original Tensor
64
59
65
60
Note on Details for preshuffle for fbgemm kernel:
@@ -80,104 +75,48 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
80
75
requires symmetric quantization
81
76
"""
82
77
83
- tensor_data_attrs = ["_data" , "group_scale" ]
84
- tensor_attributes = ["block_size" , "shape_multiplier" , "shape" ]
78
+ tensor_data_names = ["qdata" , "group_scale" ]
79
+ optional_tensor_data_names = ["group_zero" , "row_scale" ]
80
+ tensor_attribute_names = ["block_size" , "shape" ]
85
81
86
82
def __new__ (
87
83
cls ,
88
- _data ,
84
+ qdata ,
89
85
group_scale ,
90
86
group_zero ,
91
87
row_scale ,
92
88
block_size ,
93
- shape_multiplier ,
94
89
shape ,
95
90
):
96
91
kwargs = {}
97
- kwargs ["device" ] = _data .device
92
+ kwargs ["device" ] = qdata .device
98
93
kwargs ["dtype" ] = group_scale .dtype
99
94
kwargs ["requires_grad" ] = False
100
95
return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
101
96
102
97
def __init__ (
103
98
self ,
104
- _data : torch .Tensor ,
99
+ qdata : torch .Tensor ,
105
100
group_scale : torch .Tensor ,
106
101
group_zero : Optional [torch .Tensor ],
107
102
row_scale : Optional [torch .Tensor ],
108
103
block_size : List [int ],
109
- shape_multiplier : List [int ],
110
104
shape : List [int ],
111
105
):
112
106
# one and only one of group_scale and group_zero should be None
113
107
assert group_zero is None or row_scale is None
114
108
assert not (group_zero is not None and row_scale is not None )
115
- self ._data = _data
109
+ self .qdata = qdata
116
110
self .group_scale = group_scale
117
111
self .group_zero = group_zero
118
112
self .row_scale = row_scale
119
- self .shape_multiplier = shape_multiplier
120
113
self .block_size = block_size
121
114
122
- def __tensor_flatten__ (self ):
123
- if getattr (self , "group_zero" ) is None :
124
- assert getattr (self , "row_scale" ) is not None
125
- return self .tensor_data_attrs + ["row_scale" ], [
126
- getattr (self , attr ) for attr in self .tensor_attributes
127
- ]
128
- else :
129
- return self .tensor_data_attrs + ["group_zero" ], [
130
- getattr (self , attr ) for attr in self .tensor_attributes
131
- ]
132
-
133
- @classmethod
134
- def __tensor_unflatten__ (
135
- cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
136
- ):
137
- tensors = [tensor_data_dict [name ] for name in cls .tensor_data_attrs ]
138
- tensors .append (tensor_data_dict .get ("group_zero" , None ))
139
- tensors .append (tensor_data_dict .get ("row_scale" , None ))
140
- return cls (
141
- * tensors ,
142
- * tensor_attributes ,
143
- )
144
-
145
- def _apply_fn_to_data (self , fn ):
146
- tensors = [fn (getattr (self , name )) for name in self .tensor_data_attrs ]
147
- t1 = getattr (self , "group_zero" )
148
- tensors .append (fn (t1 ) if t1 is not None else None )
149
- t2 = getattr (self , "row_scale" )
150
- tensors .append (fn (t2 ) if t2 is not None else None )
151
- return self .__class__ (
152
- * tensors ,
153
- * [getattr (self , attr ) for attr in self .tensor_attributes ],
154
- )
155
-
156
- def __repr__ (self ):
157
- return (
158
- f"{ self .__class__ .__name__ } (weight={ self ._data } , block_size={ self .block_size } , "
159
- f"shape_multiplier={ self .shape_multiplier } , shape={ self .shape } , device={ self .device } , dtype={ self .dtype } , "
160
- f"requires_grad={ self .requires_grad } )"
161
- )
162
-
163
115
def _quantization_type (self ):
164
116
return f"shape={ self .shape } , block_size={ self .block_size } , device={ self .device } "
165
117
166
- def to (self , * args , ** kwargs ):
167
- kwargs = self ._get_to_kwargs (* args , ** kwargs )
168
- device = kwargs .pop ("device" )
169
- return self .__class__ (
170
- self ._data .to (device ),
171
- self .group_scale .to (device ),
172
- self .group_zero .to (device ) if self .group_zero is not None else None ,
173
- self .row_scale .to (device ) if self .row_scale is not None else None ,
174
- self .block_size ,
175
- self .shape_multiplier ,
176
- self .shape ,
177
- )
178
-
179
118
@classmethod
180
- def from_float (
119
+ def from_hp (
181
120
cls ,
182
121
w : torch .Tensor ,
183
122
block_size : List [int ],
@@ -237,17 +176,12 @@ def from_float(
237
176
group_zero = None
238
177
row_scale = group_zero_or_row_scale
239
178
240
- shape_multiplier = [1 ] * wq .ndim
241
- shape_multiplier [- 1 ] = 2
242
-
243
- del w
244
179
return Int4PreshuffledTensor (
245
- _data = wq ,
180
+ qdata = wq ,
246
181
group_scale = group_scale ,
247
182
group_zero = group_zero ,
248
183
row_scale = row_scale ,
249
184
block_size = block_size ,
250
- shape_multiplier = shape_multiplier ,
251
185
shape = original_shape ,
252
186
)
253
187
@@ -265,15 +199,16 @@ def _(func, types, args, kwargs):
265
199
orig_input_size = input_tensor .size ()
266
200
orig_out_features = weight_tensor .shape [- 2 ]
267
201
268
- wq = weight_tensor ._data .contiguous ()
202
+ wq = weight_tensor .qdata .contiguous ()
269
203
group_scale = weight_tensor .group_scale .contiguous ()
270
- # bf16 activation
271
204
if weight_tensor .group_zero is not None :
205
+ # bf16 activation
272
206
group_zero = weight_tensor .group_zero .contiguous ()
273
207
res = torch .ops .fbgemm .bf16i4bf16_shuffled (
274
208
input_tensor , wq , group_scale , group_zero
275
209
)
276
210
else :
211
+ # dynamically quantizes activation to fp8
277
212
assert weight_tensor .row_scale is not None
278
213
row_scale = weight_tensor .row_scale .contiguous ()
279
214
xq , x_scale = quantize_fp8_row (input_tensor )
@@ -295,16 +230,17 @@ def _(func, types, args, kwargs):
295
230
)
296
231
orig_input_size = input_tensor .size ()
297
232
orig_out_features = weight_tensor .shape [- 2 ]
298
- assert weight_tensor .shape_multiplier [- 1 ] == 2
299
233
300
- wq = weight_tensor ._data .contiguous ()
234
+ wq = weight_tensor .qdata .contiguous ()
301
235
group_scale = weight_tensor .group_scale .contiguous ()
302
236
if weight_tensor .group_zero is not None :
237
+ # bfloat16 activation
303
238
group_zero = weight_tensor .group_zero .contiguous ()
304
239
res = torch .ops .fbgemm .bf16i4bf16_shuffled_batched (
305
240
input_tensor , wq , group_scale , group_zero
306
241
)
307
242
else :
243
+ # dynamically quantizes activation to fp8
308
244
assert weight_tensor .row_scale is not None
309
245
row_scale = weight_tensor .row_scale .contiguous ()
310
246
xq , x_scale = quantize_fp8_row (input_tensor )
@@ -322,125 +258,6 @@ def _(func, types, args, kwargs):
322
258
return res
323
259
324
260
325
- @implements ([aten .detach .default , aten .alias .default ])
326
- def _ (func , types , args , kwargs ):
327
- return return_and_correct_aliasing (
328
- func , args , kwargs , args [0 ]._apply_fn_to_data (torch .detach )
329
- )
330
-
331
-
332
- @implements (aten .clone .default )
333
- def _ (func , types , args , kwargs ):
334
- return return_and_correct_aliasing (
335
- func , args , kwargs , args [0 ]._apply_fn_to_data (torch .clone )
336
- )
337
-
338
-
339
- def _same_metadata (self : "Int4PreshuffledTensor" , src : "Int4PreshuffledTensor" ) -> bool :
340
- return (
341
- isinstance (self , Int4PreshuffledTensor )
342
- and isinstance (src , Int4PreshuffledTensor )
343
- and self .shape == src .shape
344
- and self ._data .shape == src ._data .shape
345
- and self .group_scale .shape == src .group_scale .shape
346
- and (
347
- self .group_zero .shape == src .group_zero .shape
348
- if self .group_zero is not None
349
- else src .group_zero is None
350
- )
351
- and (
352
- self .row_scale .shape == src .row_scale .shape
353
- if self .row_scale is not None
354
- else src .row_scale is None
355
- )
356
- and self .block_size == src .block_size
357
- and self .shape_multiplier == src .shape_multiplier
358
- )
359
-
360
-
361
- @implements (aten .copy_ .default )
362
- def _ (func , types , args , kwargs ):
363
- self = args [0 ]
364
- src = args [1 ]
365
- if _same_metadata (self , src ):
366
- self_tensors = self .__tensor_flatten__ ()[0 ]
367
- for tensor_name in self_tensors :
368
- getattr (self , tensor_name ).copy_ (getattr (src , tensor_name ))
369
- return
370
- raise ValueError (
371
- f"Not supported args for copy_ due to metadata mismatch: { args [0 ], args [1 ]} "
372
- )
373
-
374
-
375
- @implements (aten .cat .default )
376
- def _ (func , types , args , kwargs ):
377
- tensors , dim = fill_defaults (args , 2 , [[], 0 ])
378
- tensor_0 = tensors [0 ]
379
- if dim < 0 :
380
- dim = dim + tensor_0 .ndim
381
-
382
- for i in range (1 , len (tensors )):
383
- assert tensor_0 ._data .ndim == tensors [i ]._data .ndim
384
- assert tensor_0 .group_scale .ndim == tensors [i ].group_scale .ndim
385
- assert tensor_0 .group_zero .ndim == tensors [i ].group_zero .ndim
386
- assert tensor_0 .block_size == tensors [i ].block_size
387
- assert tensor_0 .shape_multiplier == tensors [i ].shape_multiplier
388
-
389
- _data = [t ._data for t in tensors ]
390
- group_scale = [t .group_scale for t in tensors ]
391
- group_zero = [t .group_zero for t in tensors ]
392
-
393
- # with group wise quantization, dimension of group_scale, _data and
394
- # origianl shape will be the same, so original dim argument applies
395
- # to both _data and group_scale
396
- cat_data = aten .cat .default (_data , dim )
397
- if cat_data .ndim == 2 :
398
- sz_dim = 1 - dim
399
- else :
400
- sz_dim = dim
401
-
402
- cat_group_scale = aten .cat .default (group_scale , sz_dim )
403
- cat_group_zero = aten .cat .default (group_zero , sz_dim )
404
- new_shape = list (cat_data .shape )
405
- for i in range (len (tensor_0 .shape_multiplier )):
406
- new_shape [i ] *= tensor_0 .shape_multiplier [i ]
407
- new_shape = tuple (new_shape )
408
- new = tensor_0 .__class__ (
409
- cat_data ,
410
- cat_group_scale ,
411
- cat_group_zero ,
412
- block_size = tensor_0 .block_size ,
413
- shape_multiplier = tensor_0 .shape_multiplier ,
414
- shape = new_shape ,
415
- )
416
- return return_and_correct_aliasing (func , args , kwargs , new )
417
-
418
-
419
- @implements (aten .transpose .int )
420
- def _ (func , types , args , kwargs ):
421
- self , dim0 , dim1 = args
422
- _data = self ._data .transpose (dim0 , dim1 ).contiguous ()
423
- shape_multiplier = self .shape_multiplier .copy ()
424
- shape_multiplier [dim0 ], shape_multiplier [dim1 ] = (
425
- shape_multiplier [dim1 ],
426
- shape_multiplier [dim0 ],
427
- )
428
-
429
- tensor_shape = list (_data .shape )
430
- for i in range (len (shape_multiplier )):
431
- tensor_shape [i ] *= shape_multiplier [i ]
432
- tensor_shape = tuple (tensor_shape )
433
- new = self .__class__ (
434
- _data ,
435
- self .group_scale ,
436
- self .group_zero ,
437
- self .block_size ,
438
- shape_multiplier ,
439
- tensor_shape ,
440
- )
441
- return return_and_correct_aliasing (func , args , kwargs , new )
442
-
443
-
444
261
Int4PreshuffledTensor .__module__ = "torchao.quantization"
445
262
446
263
if TORCH_VERSION_AT_LEAST_2_5 :
0 commit comments