diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a07297d74a..446311b9b8 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1162,7 +1162,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     if config.VERSION == 2:
         block_size = list(block_size)
         if packing_format == PackingFormat.PRESHUFFLED:
-            new_weight = Int4PreshuffledTensor.from_float(
+            new_weight = Int4PreshuffledTensor.from_hp(
                 weight,
                 block_size,
                 activation_dtype=torch.bfloat16,
@@ -1281,7 +1281,7 @@ def _float8_activation_int4_weight_transform(
     )
     weight = module.weight
     block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
-    new_weight = Int4PreshuffledTensor.from_float(
+    new_weight = Int4PreshuffledTensor.from_hp(
         module.weight,
         block_size,
         activation_dtype=torch.float8_e4m3fn,
@@ -2207,7 +2207,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
         if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_float(
+            weight = Int4PreshuffledTensor.from_hp(
                 module.weight,
                 config.block_size,
                 activation_dtype=torch.bfloat16,
@@ -2226,7 +2226,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
         if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_float(
+            weight = Int4PreshuffledTensor.from_hp(
                 module.weight,
                 config.block_size,
                 activation_dtype=torch.float8_e4m3fn,
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
index bd894ceea0..16595f370e 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -9,12 +9,10 @@
 from typing import List, Optional
 
 import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
-    fill_defaults,
 )
 
 __all__ = [
@@ -42,12 +40,12 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     int4 quantization with preshuffled packing format (for all granularities)
 
     Tensor Attributes:
-      _data: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
+      qdata: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
             preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP
       for bf16 activation:
-        group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
           dtype is the same as the original Tensor dtype
-        group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
           dtype is the same as the original Tensor dtype
       for float8 activation:
         group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
@@ -57,9 +55,6 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
 
     Non-Tensor Attributes:
       block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
-      shape_multiplier: is the multipler from _data to the real weight, since
-          we pack the weight for int4, for example, when we pack the last dimension for
-          a 2D tensor, the shape_multiplier will be [1, 2]
       shape: shape of the original Tensor
 
     Note on Details for preshuffle for fbgemm kernel:
@@ -80,104 +75,48 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     requires symmetric quantization
     """
 
-    tensor_data_attrs = ["_data", "group_scale"]
-    tensor_attributes = ["block_size", "shape_multiplier", "shape"]
+    tensor_data_names = ["qdata", "group_scale"]
+    optional_tensor_data_names = ["group_zero", "row_scale"]
+    tensor_attribute_names = ["block_size", "shape"]
 
     def __new__(
         cls,
-        _data,
+        qdata,
         group_scale,
         group_zero,
         row_scale,
         block_size,
-        shape_multiplier,
         shape,
     ):
         kwargs = {}
-        kwargs["device"] = _data.device
+        kwargs["device"] = qdata.device
         kwargs["dtype"] = group_scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(
         self,
-        _data: torch.Tensor,
+        qdata: torch.Tensor,
         group_scale: torch.Tensor,
         group_zero: Optional[torch.Tensor],
         row_scale: Optional[torch.Tensor],
         block_size: List[int],
-        shape_multiplier: List[int],
         shape: List[int],
    ):
         # one and only one of group_scale and group_zero should be None
         assert group_zero is None or row_scale is None
         assert not (group_zero is not None and row_scale is not None)
-        self._data = _data
+        self.qdata = qdata
         self.group_scale = group_scale
         self.group_zero = group_zero
         self.row_scale = row_scale
-        self.shape_multiplier = shape_multiplier
         self.block_size = block_size
 
-    def __tensor_flatten__(self):
-        if getattr(self, "group_zero") is None:
-            assert getattr(self, "row_scale") is not None
-            return self.tensor_data_attrs + ["row_scale"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-        else:
-            return self.tensor_data_attrs + ["group_zero"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
-        tensors.append(tensor_data_dict.get("group_zero", None))
-        tensors.append(tensor_data_dict.get("row_scale", None))
-        return cls(
-            *tensors,
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
-        t1 = getattr(self, "group_zero")
-        tensors.append(fn(t1) if t1 is not None else None)
-        t2 = getattr(self, "row_scale")
-        tensors.append(fn(t2) if t2 is not None else None)
-        return self.__class__(
-            *tensors,
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, "
-            f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, "
-            f"requires_grad={self.requires_grad})"
-        )
-
     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
 
-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs.pop("device")
-        return self.__class__(
-            self._data.to(device),
-            self.group_scale.to(device),
-            self.group_zero.to(device) if self.group_zero is not None else None,
-            self.row_scale.to(device) if self.row_scale is not None else None,
-            self.block_size,
-            self.shape_multiplier,
-            self.shape,
-        )
-
     @classmethod
-    def from_float(
+    def from_hp(
         cls,
         w: torch.Tensor,
         block_size: List[int],
@@ -237,17 +176,12 @@ def from_float(
             group_zero = None
             row_scale = group_zero_or_row_scale
 
-        shape_multiplier = [1] * wq.ndim
-        shape_multiplier[-1] = 2
-
-        del w
         return Int4PreshuffledTensor(
-            _data=wq,
+            qdata=wq,
             group_scale=group_scale,
             group_zero=group_zero,
             row_scale=row_scale,
             block_size=block_size,
-            shape_multiplier=shape_multiplier,
             shape=original_shape,
         )
 
@@ -265,15 +199,16 @@ def _(func, types, args, kwargs):
     orig_input_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
 
-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
-    # bf16 activation
     if weight_tensor.group_zero is not None:
+        # bf16 activation
         group_zero = weight_tensor.group_zero.contiguous()
         res = torch.ops.fbgemm.bf16i4bf16_shuffled(
             input_tensor, wq, group_scale, group_zero
         )
     else:
+        # dynamically quantizes activation to fp8
         assert weight_tensor.row_scale is not None
         row_scale = weight_tensor.row_scale.contiguous()
         xq, x_scale = quantize_fp8_row(input_tensor)
@@ -295,16 +230,17 @@ def _(func, types, args, kwargs):
     )
     orig_input_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
-    assert weight_tensor.shape_multiplier[-1] == 2
 
-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
     if weight_tensor.group_zero is not None:
+        # bfloat16 activation
         group_zero = weight_tensor.group_zero.contiguous()
         res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
             input_tensor, wq, group_scale, group_zero
         )
     else:
+        # dynamically quantizes activation to fp8
         assert weight_tensor.row_scale is not None
         row_scale = weight_tensor.row_scale.contiguous()
         xq, x_scale = quantize_fp8_row(input_tensor)
@@ -322,125 +258,6 @@ def _(func, types, args, kwargs):
     return res
 
 
-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool:
-    return (
-        isinstance(self, Int4PreshuffledTensor)
-        and isinstance(src, Int4PreshuffledTensor)
-        and self.shape == src.shape
-        and self._data.shape == src._data.shape
-        and self.group_scale.shape == src.group_scale.shape
-        and (
-            self.group_zero.shape == src.group_zero.shape
-            if self.group_zero is not None
-            else src.group_zero is None
-        )
-        and (
-            self.row_scale.shape == src.row_scale.shape
-            if self.row_scale is not None
-            else src.row_scale is None
-        )
-        and self.block_size == src.block_size
-        and self.shape_multiplier == src.shape_multiplier
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
-@implements(aten.cat.default)
-def _(func, types, args, kwargs):
-    tensors, dim = fill_defaults(args, 2, [[], 0])
-    tensor_0 = tensors[0]
-    if dim < 0:
-        dim = dim + tensor_0.ndim
-
-    for i in range(1, len(tensors)):
-        assert tensor_0._data.ndim == tensors[i]._data.ndim
-        assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
-        assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim
-        assert tensor_0.block_size == tensors[i].block_size
-        assert tensor_0.shape_multiplier == tensors[i].shape_multiplier
-
-    _data = [t._data for t in tensors]
-    group_scale = [t.group_scale for t in tensors]
-    group_zero = [t.group_zero for t in tensors]
-
-    # with group wise quantization, dimension of group_scale, _data and
-    # origianl shape will be the same, so original dim argument applies
-    # to both _data and group_scale
-    cat_data = aten.cat.default(_data, dim)
-    if cat_data.ndim == 2:
-        sz_dim = 1 - dim
-    else:
-        sz_dim = dim
-
-    cat_group_scale = aten.cat.default(group_scale, sz_dim)
-    cat_group_zero = aten.cat.default(group_zero, sz_dim)
-    new_shape = list(cat_data.shape)
-    for i in range(len(tensor_0.shape_multiplier)):
-        new_shape[i] *= tensor_0.shape_multiplier[i]
-    new_shape = tuple(new_shape)
-    new = tensor_0.__class__(
-        cat_data,
-        cat_group_scale,
-        cat_group_zero,
-        block_size=tensor_0.block_size,
-        shape_multiplier=tensor_0.shape_multiplier,
-        shape=new_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
-@implements(aten.transpose.int)
-def _(func, types, args, kwargs):
-    self, dim0, dim1 = args
-    _data = self._data.transpose(dim0, dim1).contiguous()
-    shape_multiplier = self.shape_multiplier.copy()
-    shape_multiplier[dim0], shape_multiplier[dim1] = (
-        shape_multiplier[dim1],
-        shape_multiplier[dim0],
-    )
-
-    tensor_shape = list(_data.shape)
-    for i in range(len(shape_multiplier)):
-        tensor_shape[i] *= shape_multiplier[i]
-    tensor_shape = tuple(tensor_shape)
-    new = self.__class__(
-        _data,
-        self.group_scale,
-        self.group_zero,
-        self.block_size,
-        shape_multiplier,
-        tensor_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
 Int4PreshuffledTensor.__module__ = "torchao.quantization"
 
 if TORCH_VERSION_AT_LEAST_2_5:
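
Note (not part of the patch): a minimal usage sketch of the renamed constructor, assuming a CUDA build with the fbgemm-gpu genai kernels available; the import path, weight shape, and group size below are illustrative assumptions, not values taken from the diff.

import torch
# import path is an assumption; the class may also be re-exported elsewhere
from torchao.quantization.quantize_.workflows import Int4PreshuffledTensor

# high-precision weight of a Linear layer with in_features K=1024, out_features N=1024
w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
block_size = [1, 128]  # groupwise quantization along K with group_size=128

# bf16 activation path: stores qdata + group_scale + group_zero
w_bf16 = Int4PreshuffledTensor.from_hp(w, block_size, activation_dtype=torch.bfloat16)

# fp8 activation path: stores qdata + group_scale + row_scale (symmetric quantization)
w_fp8 = Int4PreshuffledTensor.from_hp(w, block_size, activation_dtype=torch.float8_e4m3fn)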