
Update Int4PreshuffledTensor to align with implementation details of the Float8Tensor #2738

Merged 1 commit on Aug 12, 2025
8 changes: 4 additions & 4 deletions torchao/quantization/quant_api.py
@@ -1162,7 +1162,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
if config.VERSION == 2:
block_size = list(block_size)
if packing_format == PackingFormat.PRESHUFFLED:
new_weight = Int4PreshuffledTensor.from_float(
new_weight = Int4PreshuffledTensor.from_hp(
weight,
block_size,
activation_dtype=torch.bfloat16,
@@ -1281,7 +1281,7 @@ def _float8_activation_int4_weight_transform(
)
weight = module.weight
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
new_weight = Int4PreshuffledTensor.from_float(
new_weight = Int4PreshuffledTensor.from_hp(
module.weight,
block_size,
activation_dtype=torch.float8_e4m3fn,
@@ -2207,7 +2207,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
and (config.output_dtype == torch.bfloat16)
):
if config.preshuffle:
weight = Int4PreshuffledTensor.from_float(
weight = Int4PreshuffledTensor.from_hp(
module.weight,
config.block_size,
activation_dtype=torch.bfloat16,
@@ -2226,7 +2226,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
and (config.output_dtype == torch.bfloat16)
):
if config.preshuffle:
weight = Int4PreshuffledTensor.from_float(
weight = Int4PreshuffledTensor.from_hp(
module.weight,
config.block_size,
activation_dtype=torch.float8_e4m3fn,
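
Note on the quant_api.py changes above: the only call-site change is the factory rename from Int4PreshuffledTensor.from_float to Int4PreshuffledTensor.from_hp; the arguments are unchanged. Below is a minimal sketch of the two call patterns. The import path (inferred from the `__module__` assignment at the end of the tensor file), the weight shape, the group size, and the CUDA/fbgemm availability are illustrative assumptions, not part of the diff.

```python
# Illustrative sketch of the renamed factory as called above (not part of the PR).
# Assumptions: Int4PreshuffledTensor is importable from torchao.quantization,
# CUDA and fbgemm kernels are available.
import torch
from torchao.quantization import Int4PreshuffledTensor

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
group_size = 128
# groupwise block_size derived the same way as in the transform above:
# 1 for every leading dim, group_size for the last (K) dim
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])  # -> (1, 128)

# bf16 activation path (weight-only int4)
w_bf16_act = Int4PreshuffledTensor.from_hp(
    weight, block_size, activation_dtype=torch.bfloat16
)

# float8 activation path (dynamic fp8 activation + int4 weight)
w_fp8_act = Int4PreshuffledTensor.from_hp(
    weight, block_size, activation_dtype=torch.float8_e4m3fn
)
```
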
219 changes: 18 additions & 201 deletions torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -9,12 +9,10 @@
from typing import List, Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
fill_defaults,
)

__all__ = [
@@ -42,12 +40,12 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
int4 quantization with preshuffled packing format (for all granularities)
Tensor Attributes:
_data: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
qdata: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP
for bf16 activation:
group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype
group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype
for float8 activation:
group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
@@ -57,9 +55,6 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
Non-Tensor Attributes:
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
shape_multiplier: is the multipler from _data to the real weight, since
we pack the weight for int4, for example, when we pack the last dimension for
a 2D tensor, the shape_multiplier will be [1, 2]
shape: shape of the original Tensor
Note on Details for preshuffle for fbgemm kernel:
@@ -80,104 +75,48 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
requires symmetric quantization
"""

tensor_data_attrs = ["_data", "group_scale"]
tensor_attributes = ["block_size", "shape_multiplier", "shape"]
tensor_data_names = ["qdata", "group_scale"]
optional_tensor_data_names = ["group_zero", "row_scale"]
tensor_attribute_names = ["block_size", "shape"]
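
The heart of the refactor is visible right here: the hand-maintained tensor_data_attrs / tensor_attributes lists are replaced by the declarative tensor_data_names, optional_tensor_data_names, and tensor_attribute_names used by Float8Tensor. That appears to let the shared TorchAOBaseTensor machinery cover flatten/unflatten, _apply_fn_to_data, to, detach/clone/copy_, and __repr__ (via _quantization_type), which is presumably why the hand-written versions are deleted further down. The sketch below is only an illustration of how generic plumbing can be driven by such declarations; it is NOT the actual TorchAOBaseTensor implementation.

```python
# Illustrative only -- NOT the real TorchAOBaseTensor code. A picture of how
# declarative attribute lists can drive the generic tensor-subclass plumbing
# that the hand-written methods removed below used to implement per class.
class DeclarativeTensorMixin:
    tensor_data_names = []            # required tensor attributes
    optional_tensor_data_names = []   # tensor attributes that may be None
    tensor_attribute_names = []       # non-tensor metadata (e.g. block_size, shape)

    def __tensor_flatten__(self):
        # only serialize the optional tensors that are actually present
        names = list(self.tensor_data_names) + [
            n for n in self.optional_tensor_data_names if getattr(self, n) is not None
        ]
        return names, [getattr(self, n) for n in self.tensor_attribute_names]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        tensors = [tensor_data_dict[n] for n in cls.tensor_data_names]
        tensors += [tensor_data_dict.get(n) for n in cls.optional_tensor_data_names]
        return cls(*tensors, *tensor_attributes)

    def _apply_fn_to_data(self, fn):
        tensors = [fn(getattr(self, n)) for n in self.tensor_data_names]
        for n in self.optional_tensor_data_names:
            t = getattr(self, n)
            tensors.append(fn(t) if t is not None else None)
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        return self.__class__(*tensors, *attrs)
```
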

def __new__(
cls,
_data,
qdata,
group_scale,
group_zero,
row_scale,
block_size,
shape_multiplier,
shape,
):
kwargs = {}
kwargs["device"] = _data.device
kwargs["device"] = qdata.device
kwargs["dtype"] = group_scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
_data: torch.Tensor,
qdata: torch.Tensor,
group_scale: torch.Tensor,
group_zero: Optional[torch.Tensor],
row_scale: Optional[torch.Tensor],
block_size: List[int],
shape_multiplier: List[int],
shape: List[int],
):
# one and only one of group_zero and row_scale should be None
assert group_zero is None or row_scale is None
assert not (group_zero is not None and row_scale is not None)
self._data = _data
self.qdata = qdata
self.group_scale = group_scale
self.group_zero = group_zero
self.row_scale = row_scale
self.shape_multiplier = shape_multiplier
self.block_size = block_size

def __tensor_flatten__(self):
if getattr(self, "group_zero") is None:
assert getattr(self, "row_scale") is not None
return self.tensor_data_attrs + ["row_scale"], [
getattr(self, attr) for attr in self.tensor_attributes
]
else:
return self.tensor_data_attrs + ["group_zero"], [
getattr(self, attr) for attr in self.tensor_attributes
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
tensors.append(tensor_data_dict.get("group_zero", None))
tensors.append(tensor_data_dict.get("row_scale", None))
return cls(
*tensors,
*tensor_attributes,
)

def _apply_fn_to_data(self, fn):
tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
t1 = getattr(self, "group_zero")
tensors.append(fn(t1) if t1 is not None else None)
t2 = getattr(self, "row_scale")
tensors.append(fn(t2) if t2 is not None else None)
return self.__class__(
*tensors,
*[getattr(self, attr) for attr in self.tensor_attributes],
)

def __repr__(self):
return (
f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, "
f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, "
f"requires_grad={self.requires_grad})"
)

def _quantization_type(self):
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs.pop("device")
return self.__class__(
self._data.to(device),
self.group_scale.to(device),
self.group_zero.to(device) if self.group_zero is not None else None,
self.row_scale.to(device) if self.row_scale is not None else None,
self.block_size,
self.shape_multiplier,
self.shape,
)

@classmethod
def from_float(
def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
@@ -237,17 +176,12 @@ def from_float(
group_zero = None
row_scale = group_zero_or_row_scale

shape_multiplier = [1] * wq.ndim
shape_multiplier[-1] = 2

del w
return Int4PreshuffledTensor(
_data=wq,
qdata=wq,
group_scale=group_scale,
group_zero=group_zero,
row_scale=row_scale,
block_size=block_size,
shape_multiplier=shape_multiplier,
shape=original_shape,
)
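
With shape_multiplier removed, the logical shape is stored directly in shape, and the packed qdata shape always follows from it: two int4 values are packed per byte along the last dimension, so a 2D weight of shape (N, K) has qdata of shape (N, K/2), as the class docstring states. A hedged illustration of that relationship follows; the helper name is made up for this note and does not exist in torchao.

```python
# Hypothetical helper (not in torchao) illustrating why shape_multiplier is no
# longer needed: the packed qdata shape is recoverable from the stored logical
# shape, since two int4 values are packed per byte on the last dimension.
from typing import Sequence, Tuple

def packed_qdata_shape(logical_shape: Sequence[int]) -> Tuple[int, ...]:
    *leading, k = logical_shape
    assert k % 2 == 0, "K must be even to pack two int4 values per byte"
    return (*leading, k // 2)

assert packed_qdata_shape((4096, 4096)) == (4096, 2048)        # 2D weight (N, K)
assert packed_qdata_shape((8, 4096, 4096)) == (8, 4096, 2048)  # 3D weight (B, N, K)
```
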

@@ -265,15 +199,16 @@ def _(func, types, args, kwargs):
orig_input_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

wq = weight_tensor._data.contiguous()
wq = weight_tensor.qdata.contiguous()
group_scale = weight_tensor.group_scale.contiguous()
# bf16 activation
if weight_tensor.group_zero is not None:
# bf16 activation
group_zero = weight_tensor.group_zero.contiguous()
res = torch.ops.fbgemm.bf16i4bf16_shuffled(
input_tensor, wq, group_scale, group_zero
)
else:
# dynamically quantizes activation to fp8
assert weight_tensor.row_scale is not None
row_scale = weight_tensor.row_scale.contiguous()
xq, x_scale = quantize_fp8_row(input_tensor)
Author comment: btw @vkuzo, we didn't use _choose_quant_func_and_quantize_tensor and float8 quant args yet, but we can add that if there are multiple choices of float8 activation quant in the future. Please let me know if you have different thoughts here.
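
For reference, a hedged sketch of the dispatch pattern in the linear override above: the presence of group_zero selects the bf16-activation kernel, otherwise the activation is dynamically quantized to fp8 row-wise with quantize_fp8_row first. The fp8 x int4 kernel call itself is not visible in this diff, so it is left as a placeholder, and the quantize_fp8_row import path shown is an assumption.

```python
# Sketch of the branch in the linear override above (not a full implementation).
# Assumption: quantize_fp8_row comes from fbgemm's triton fp8 utilities; the
# exact import used in this module is not shown in the visible part of the diff.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

def _int4_preshuffled_matmul_sketch(input_tensor, weight_tensor):
    wq = weight_tensor.qdata.contiguous()
    group_scale = weight_tensor.group_scale.contiguous()
    if weight_tensor.group_zero is not None:
        # bf16 activation: weight was built with activation_dtype=torch.bfloat16
        group_zero = weight_tensor.group_zero.contiguous()
        return torch.ops.fbgemm.bf16i4bf16_shuffled(
            input_tensor, wq, group_scale, group_zero
        )
    # fp8 activation: dynamically quantize the activation row-wise, then call
    # the fbgemm fp8 x int4 kernel (elided in the visible diff above)
    row_scale = weight_tensor.row_scale.contiguous()
    xq, x_scale = quantize_fp8_row(input_tensor)
    ...  # fp8 x int4 shuffled kernel call, not shown here
```
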

@@ -295,16 +230,17 @@ def _(func, types, args, kwargs):
)
orig_input_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]
assert weight_tensor.shape_multiplier[-1] == 2

wq = weight_tensor._data.contiguous()
wq = weight_tensor.qdata.contiguous()
group_scale = weight_tensor.group_scale.contiguous()
if weight_tensor.group_zero is not None:
# bfloat16 activation
group_zero = weight_tensor.group_zero.contiguous()
res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
input_tensor, wq, group_scale, group_zero
)
else:
# dynamically quantizes activation to fp8
assert weight_tensor.row_scale is not None
row_scale = weight_tensor.row_scale.contiguous()
xq, x_scale = quantize_fp8_row(input_tensor)
@@ -322,125 +258,6 @@ def _(func, types, args, kwargs):
return res
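
End to end, once a module weight has been swapped for an Int4PreshuffledTensor, ordinary linear calls (2D weight) and bmm calls (3D weight) route through the two overrides above. A hedged usage sketch for the bf16-activation path; shapes, device, and fbgemm availability are assumptions.

```python
# Hedged usage sketch (assumes CUDA + fbgemm): F.linear on a quantized 2D weight
# routes through the linear override above via the bf16 x int4 kernel.
import torch
import torch.nn.functional as F
from torchao.quantization import Int4PreshuffledTensor

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(11008, 4096, dtype=torch.bfloat16, device="cuda")

wq = Int4PreshuffledTensor.from_hp(w, [1, 128], activation_dtype=torch.bfloat16)
y = F.linear(x, wq)  # -> shape (16, 11008), computed by bf16i4bf16_shuffled
```
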


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool:
return (
isinstance(self, Int4PreshuffledTensor)
and isinstance(src, Int4PreshuffledTensor)
and self.shape == src.shape
and self._data.shape == src._data.shape
and self.group_scale.shape == src.group_scale.shape
and (
self.group_zero.shape == src.group_zero.shape
if self.group_zero is not None
else src.group_zero is None
)
and (
self.row_scale.shape == src.row_scale.shape
if self.row_scale is not None
else src.row_scale is None
)
and self.block_size == src.block_size
and self.shape_multiplier == src.shape_multiplier
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
)


@implements(aten.cat.default)
def _(func, types, args, kwargs):
tensors, dim = fill_defaults(args, 2, [[], 0])
tensor_0 = tensors[0]
if dim < 0:
dim = dim + tensor_0.ndim

for i in range(1, len(tensors)):
assert tensor_0._data.ndim == tensors[i]._data.ndim
assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim
assert tensor_0.block_size == tensors[i].block_size
assert tensor_0.shape_multiplier == tensors[i].shape_multiplier

_data = [t._data for t in tensors]
group_scale = [t.group_scale for t in tensors]
group_zero = [t.group_zero for t in tensors]

# with group wise quantization, dimension of group_scale, _data and
# origianl shape will be the same, so original dim argument applies
# to both _data and group_scale
cat_data = aten.cat.default(_data, dim)
if cat_data.ndim == 2:
sz_dim = 1 - dim
else:
sz_dim = dim

cat_group_scale = aten.cat.default(group_scale, sz_dim)
cat_group_zero = aten.cat.default(group_zero, sz_dim)
new_shape = list(cat_data.shape)
for i in range(len(tensor_0.shape_multiplier)):
new_shape[i] *= tensor_0.shape_multiplier[i]
new_shape = tuple(new_shape)
new = tensor_0.__class__(
cat_data,
cat_group_scale,
cat_group_zero,
block_size=tensor_0.block_size,
shape_multiplier=tensor_0.shape_multiplier,
shape=new_shape,
)
return return_and_correct_aliasing(func, args, kwargs, new)


@implements(aten.transpose.int)
def _(func, types, args, kwargs):
self, dim0, dim1 = args
_data = self._data.transpose(dim0, dim1).contiguous()
shape_multiplier = self.shape_multiplier.copy()
shape_multiplier[dim0], shape_multiplier[dim1] = (
shape_multiplier[dim1],
shape_multiplier[dim0],
)

tensor_shape = list(_data.shape)
for i in range(len(shape_multiplier)):
tensor_shape[i] *= shape_multiplier[i]
tensor_shape = tuple(tensor_shape)
new = self.__class__(
_data,
self.group_scale,
self.group_zero,
self.block_size,
shape_multiplier,
tensor_shape,
)
return return_and_correct_aliasing(func, args, kwargs, new)


Int4PreshuffledTensor.__module__ = "torchao.quantization"

if TORCH_VERSION_AT_LEAST_2_5: