
Commit 584723f

Update Int4PreshuffledTensor to align with implementation details of the Float8Tensor

Summary:
Similar to #2687, this updates Int4PreshuffledTensor to align its implementation details with Float8Tensor, and uses TorchAOBaseTensor to simplify some of the implementation.

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2738, branch: jerryzh168/stack/20

1 parent 6dfe202 commit 584723f
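
The user-facing change is the constructor rename from Int4PreshuffledTensor.from_float to Int4PreshuffledTensor.from_hp (construct from a high-precision weight). A minimal call-site sketch, assuming Int4PreshuffledTensor is importable from torchao.quantization (as the __module__ assignment in this commit suggests), a CUDA device with the fbgemm kernels available, and hypothetical weight sizes:

import torch
from torchao.quantization import Int4PreshuffledTensor

# hypothetical 2D bfloat16 weight: N=256 output features, K=512 input features
weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
block_size = [1, 128]  # groupwise quantization with group_size=128 along K

# before this commit:
#   Int4PreshuffledTensor.from_float(weight, block_size, activation_dtype=torch.bfloat16)
# after this commit:
quantized = Int4PreshuffledTensor.from_hp(
    weight,
    block_size,
    activation_dtype=torch.bfloat16,  # or torch.float8_e4m3fn for fp8 activation
)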

File tree

2 files changed (+20, -204 lines)

torchao/quantization/quant_api.py
torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

torchao/quantization/quant_api.py

Lines changed: 4 additions & 4 deletions
@@ -1162,7 +1162,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     if config.VERSION == 2:
         block_size = list(block_size)
         if packing_format == PackingFormat.PRESHUFFLED:
-            new_weight = Int4PreshuffledTensor.from_float(
+            new_weight = Int4PreshuffledTensor.from_hp(
                 weight,
                 block_size,
                 activation_dtype=torch.bfloat16,
@@ -1281,7 +1281,7 @@ def _float8_activation_int4_weight_transform(
     )
     weight = module.weight
     block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
-    new_weight = Int4PreshuffledTensor.from_float(
+    new_weight = Int4PreshuffledTensor.from_hp(
         module.weight,
         block_size,
         activation_dtype=torch.float8_e4m3fn,
@@ -2207,7 +2207,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
         if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_float(
+            weight = Int4PreshuffledTensor.from_hp(
                 module.weight,
                 config.block_size,
                 activation_dtype=torch.bfloat16,
@@ -2226,7 +2226,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
        if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_float(
+            weight = Int4PreshuffledTensor.from_hp(
                 module.weight,
                 config.block_size,
                 activation_dtype=torch.float8_e4m3fn,
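
As a side note on the _float8_activation_int4_weight_transform hunk above, block_size is derived from the weight rank: every leading dimension gets block size 1 and only the last (K) dimension is grouped. A small standalone illustration with hypothetical sizes:

import torch

group_size = 128
weight_2d = torch.empty(256, 512)     # (N, K)
weight_3d = torch.empty(8, 256, 512)  # (B, N, K)

# same expression as in the hunk above
print(tuple([1 for _ in range(weight_2d.ndim - 1)] + [group_size]))  # (1, 128)
print(tuple([1 for _ in range(weight_3d.ndim - 1)] + [group_size]))  # (1, 1, 128)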

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 16 additions & 200 deletions
@@ -9,12 +9,10 @@
 from typing import List, Optional

 import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing

 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
-    fill_defaults,
 )

 __all__ = [
@@ -42,12 +40,12 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     int4 quantization with preshuffled packing format (for all granularities)

     Tensor Attributes:
-      _data: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
+      qdata: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
       preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP
       for bf16 activation:
-        group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
          dtype is the same as the original Tensor dtype
-        group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
          dtype is the same as the original Tensor dtype
       for float8 activation:
         group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
@@ -57,9 +55,6 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):

     Non-Tensor Attributes:
       block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
-      shape_multiplier: is the multipler from _data to the real weight, since
-         we pack the weight for int4, for example, when we pack the last dimension for
-         a 2D tensor, the shape_multiplier will be [1, 2]
       shape: shape of the original Tensor

     Note on Details for preshuffle for fbgemm kernel:
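
For concreteness, the attribute shapes documented in the docstring above work out as follows for hypothetical sizes N=256, K=4096, group_size=128 (a standalone arithmetic sketch):

# hypothetical sizes to make the documented shapes concrete
N, K, group_size = 256, 4096, 128

qdata_shape = (N, K // 2)                       # (256, 2048): two int4 values per byte
group_scale_bf16 = (K // group_size, N)         # (32, 256), bf16 activation
group_zero_bf16 = (K // group_size, N)          # (32, 256)
group_scale_fp8 = (K // group_size // 8, 8, N)  # (4, 8, 256), float8 activation

print(qdata_shape, group_scale_bf16, group_zero_bf16, group_scale_fp8)
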
@@ -80,104 +75,48 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     requires symmetric quantization
     """

-    tensor_data_attrs = ["_data", "group_scale"]
-    tensor_attributes = ["block_size", "shape_multiplier", "shape"]
+    tensor_data_names = ["qdata", "group_scale"]
+    optional_tensor_data_names = ["group_zero", "row_scale"]
+    tensor_attribute_names = ["block_size", "shape"]

     def __new__(
         cls,
-        _data,
+        qdata,
         group_scale,
         group_zero,
         row_scale,
         block_size,
-        shape_multiplier,
         shape,
     ):
         kwargs = {}
-        kwargs["device"] = _data.device
+        kwargs["device"] = qdata.device
         kwargs["dtype"] = group_scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

     def __init__(
         self,
-        _data: torch.Tensor,
+        qdata: torch.Tensor,
         group_scale: torch.Tensor,
         group_zero: Optional[torch.Tensor],
         row_scale: Optional[torch.Tensor],
         block_size: List[int],
-        shape_multiplier: List[int],
         shape: List[int],
     ):
         # one and only one of group_scale and group_zero should be None
         assert group_zero is None or row_scale is None
         assert not (group_zero is not None and row_scale is not None)
-        self._data = _data
+        self.qdata = qdata
         self.group_scale = group_scale
         self.group_zero = group_zero
         self.row_scale = row_scale
-        self.shape_multiplier = shape_multiplier
         self.block_size = block_size

-    def __tensor_flatten__(self):
-        if getattr(self, "group_zero") is None:
-            assert getattr(self, "row_scale") is not None
-            return self.tensor_data_attrs + ["row_scale"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-        else:
-            return self.tensor_data_attrs + ["group_zero"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
-        tensors.append(tensor_data_dict.get("group_zero", None))
-        tensors.append(tensor_data_dict.get("row_scale", None))
-        return cls(
-            *tensors,
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
-        t1 = getattr(self, "group_zero")
-        tensors.append(fn(t1) if t1 is not None else None)
-        t2 = getattr(self, "row_scale")
-        tensors.append(fn(t2) if t2 is not None else None)
-        return self.__class__(
-            *tensors,
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, "
-            f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, "
-            f"requires_grad={self.requires_grad})"
-        )
-
     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs.pop("device")
-        return self.__class__(
-            self._data.to(device),
-            self.group_scale.to(device),
-            self.group_zero.to(device) if self.group_zero is not None else None,
-            self.row_scale.to(device) if self.row_scale is not None else None,
-            self.block_size,
-            self.shape_multiplier,
-            self.shape,
-        )
-
     @classmethod
-    def from_float(
+    def from_hp(
         cls,
         w: torch.Tensor,
         block_size: List[int],
@@ -237,17 +176,12 @@ def from_float(
             group_zero = None
             row_scale = group_zero_or_row_scale

-        shape_multiplier = [1] * wq.ndim
-        shape_multiplier[-1] = 2
-
-        del w
         return Int4PreshuffledTensor(
-            _data=wq,
+            qdata=wq,
             group_scale=group_scale,
             group_zero=group_zero,
             row_scale=row_scale,
             block_size=block_size,
-            shape_multiplier=shape_multiplier,
             shape=original_shape,
         )

@@ -265,7 +199,7 @@ def _(func, types, args, kwargs):
     orig_input_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]

-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
     # bf16 activation
     if weight_tensor.group_zero is not None:
@@ -295,16 +229,17 @@ def _(func, types, args, kwargs):
     )
     orig_input_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
-    assert weight_tensor.shape_multiplier[-1] == 2

-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
     if weight_tensor.group_zero is not None:
+        # bfloat16 activation
         group_zero = weight_tensor.group_zero.contiguous()
         res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
             input_tensor, wq, group_scale, group_zero
         )
     else:
+        # fp8 activation
         assert weight_tensor.row_scale is not None
         row_scale = weight_tensor.row_scale.contiguous()
         xq, x_scale = quantize_fp8_row(input_tensor)
@@ -322,125 +257,6 @@ def _(func, types, args, kwargs):
     return res


-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool:
-    return (
-        isinstance(self, Int4PreshuffledTensor)
-        and isinstance(src, Int4PreshuffledTensor)
-        and self.shape == src.shape
-        and self._data.shape == src._data.shape
-        and self.group_scale.shape == src.group_scale.shape
-        and (
-            self.group_zero.shape == src.group_zero.shape
-            if self.group_zero is not None
-            else src.group_zero is None
-        )
-        and (
-            self.row_scale.shape == src.row_scale.shape
-            if self.row_scale is not None
-            else src.row_scale is None
-        )
-        and self.block_size == src.block_size
-        and self.shape_multiplier == src.shape_multiplier
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
-@implements(aten.cat.default)
-def _(func, types, args, kwargs):
-    tensors, dim = fill_defaults(args, 2, [[], 0])
-    tensor_0 = tensors[0]
-    if dim < 0:
-        dim = dim + tensor_0.ndim
-
-    for i in range(1, len(tensors)):
-        assert tensor_0._data.ndim == tensors[i]._data.ndim
-        assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
-        assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim
-        assert tensor_0.block_size == tensors[i].block_size
-        assert tensor_0.shape_multiplier == tensors[i].shape_multiplier
-
-    _data = [t._data for t in tensors]
-    group_scale = [t.group_scale for t in tensors]
-    group_zero = [t.group_zero for t in tensors]
-
-    # with group wise quantization, dimension of group_scale, _data and
-    # origianl shape will be the same, so original dim argument applies
-    # to both _data and group_scale
-    cat_data = aten.cat.default(_data, dim)
-    if cat_data.ndim == 2:
-        sz_dim = 1 - dim
-    else:
-        sz_dim = dim
-
-    cat_group_scale = aten.cat.default(group_scale, sz_dim)
-    cat_group_zero = aten.cat.default(group_zero, sz_dim)
-    new_shape = list(cat_data.shape)
-    for i in range(len(tensor_0.shape_multiplier)):
-        new_shape[i] *= tensor_0.shape_multiplier[i]
-    new_shape = tuple(new_shape)
-    new = tensor_0.__class__(
-        cat_data,
-        cat_group_scale,
-        cat_group_zero,
-        block_size=tensor_0.block_size,
-        shape_multiplier=tensor_0.shape_multiplier,
-        shape=new_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
-@implements(aten.transpose.int)
-def _(func, types, args, kwargs):
-    self, dim0, dim1 = args
-    _data = self._data.transpose(dim0, dim1).contiguous()
-    shape_multiplier = self.shape_multiplier.copy()
-    shape_multiplier[dim0], shape_multiplier[dim1] = (
-        shape_multiplier[dim1],
-        shape_multiplier[dim0],
-    )
-
-    tensor_shape = list(_data.shape)
-    for i in range(len(shape_multiplier)):
-        tensor_shape[i] *= shape_multiplier[i]
-    tensor_shape = tuple(tensor_shape)
-    new = self.__class__(
-        _data,
-        self.group_scale,
-        self.group_zero,
-        self.block_size,
-        shape_multiplier,
-        tensor_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
 Int4PreshuffledTensor.__module__ = "torchao.quantization"

 if TORCH_VERSION_AT_LEAST_2_5:
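
The bulk of the deletions above is possible because TorchAOBaseTensor can derive generic subclass plumbing (flatten/unflatten, _apply_fn_to_data, repr, device moves) from the class-level name lists, the same pattern Float8Tensor follows. A minimal sketch of that declarative pattern, assuming only what the diff shows (the name lists plus hand-written __new__/__init__) and using a toy payload instead of the real preshuffled int4 packing:

import torch
from torchao.utils import TorchAOBaseTensor

class ToyWrappedTensor(TorchAOBaseTensor):
    # declare the attributes once; the base class uses these lists to derive
    # the boilerplate that Int4PreshuffledTensor used to spell out by hand
    tensor_data_names = ["qdata"]
    optional_tensor_data_names = ["scale"]
    tensor_attribute_names = ["block_size", "shape"]

    def __new__(cls, qdata, scale, block_size, shape):
        kwargs = {"device": qdata.device, "dtype": qdata.dtype, "requires_grad": False}
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, qdata, scale, block_size, shape):
        self.qdata = qdata
        self.scale = scale
        self.block_size = block_size

# toy usage: an int8 payload standing in for packed int4 data
t = ToyWrappedTensor(torch.zeros(4, 8, dtype=torch.int8), None, [1, 16], [4, 16])
print(t.shape, t.block_size)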
