diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
new file mode 100644
index 0000000000..77238cebb6
--- /dev/null
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao.quantization import (
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+)
+
+
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+class TestIntxUnpackedTensor(TestCase):
+    def setUp(self):
+        self.config = IntxWeightOnlyConfig(
+            weight_dtype=torch.int4,
+            granularity=PerGroup(32),
+            version=2,
+        )
+
+    def test_embedding(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randint(low=0, high=128, size=(10,), device=device)
+        embedding = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
+        original = embedding(input)
+        quantize_(embedding, self.config)
+        quantized = embedding(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_linear(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, self.config)
+        quantized = linear(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+
+        self.assertEqual(weight1.int_data, dummy.weight.int_data.narrow(0, 0, 64))
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+
+        self.assertEqual(weight2.int_data, dummy.weight.int_data.narrow(1, 0, 128))
+        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))
+
+        # check that results from the sliced weight, before and after intx
+        # quantization, do not differ too much
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 20
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        device = "cpu"
+        l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert param.data.int_data.data_ptr() == param_data.int_data.data_ptr()
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
+        orig_value = param.data.int_data[0][0].item()
+
+        # dummy_l has a random weight (shouldn't be all zeros)
+        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.int_data[0][0] != orig_value
+
+    def test_to_dtype(self):
+        activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
+        activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
+        activations_fp16 = torch.randn(1, 128, dtype=torch.float16)
+
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+
+        linear.to(dtype=torch.float16)
+        linear(activations_fp16)
+
+        linear.to(dtype=torch.float32)
+        linear(activations_fp32)
+
+        linear.to(dtype=torch.bfloat16)
+        linear(activations_bf16)
+
+    def test_export(self):
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+        ep = torch.export.export(linear, (torch.randn(1, 128),))
+        assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
index 1a87245ad4..459c1c5e97 100644
--- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py
+++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -183,10 +183,9 @@ def test_shared_embedding(self):
         self.assertTrue(torch.allclose(result, exported_result))
 
         # Check the shared_embedding and linear ops use the same lifted weight
-        weight = "b_getattr_l__fn_____0___unembedding_packed_weights"
         expected_lines = [
-            f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)",
-            f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)",
+            "torch.ops.torchao._shared_embedding_4bit.default",
+            "torch.ops.torchao._linear_8bit_act_4bit_weight.default",
         ]
         for line in expected_lines:
             FileCheck().check_count(line, 1, exactly=True).run(
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 5d79563ab1..072455bbdd 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -76,6 +76,7 @@
     Float8Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -563,6 +564,10 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
+def _embedding_extra_repr(self):
+    return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):
@@ -2060,6 +2065,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
+    packing_format: PackingFormat = PackingFormat.UNPACKED_TO_INT8
+    version: int = 1
 
     def __post_init__(self):
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
@@ -2078,16 +2085,13 @@ def __post_init__(self):
         )
 
 
-@register_quantize_module_handler(IntxWeightOnlyConfig)
-def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
+def _intx_weight_only_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
+    packing_format = config.packing_format
 
     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2102,11 +2106,28 @@ def _intx_weight_only_transform(
     else:
         raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
 
+    block_size = (1, group_size)
+
+    if config.version == 2:
+        if config.packing_format == PackingFormat.UNPACKED_TO_INT8:
+            new_weight = IntxUnpackedTensor.from_hp(
+                weight,
+                block_size,
+                weight_dtype,
+                mapping_type=mapping_type,
+            )
+            if scale_dtype is not None and scale_dtype != weight.dtype:
+                new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=mapping_type,
-        block_size=(1, group_size),
+        block_size=block_size,
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
@@ -2116,7 +2137,25 @@ def _intx_weight_only_transform(
         zero_point_domain=ZeroPointDomain.INT,
         _layout=layout,
     )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return weight
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying intx weight only quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    elif isinstance(module, nn.Embedding):
+        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
+    return module
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 77ed2790c5..405134578c 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -30,3 +30,8 @@ class PackingFormat(str, Enum):
     preshuffled is referring to the preshuffled format used by fbgemm kernels
     """
     PRESHUFFLED = "preshuffled"
+
+    """
+    Unpacked means the subbyte quantized data is stored as int8
+    """
+    UNPACKED_TO_INT8 = "unpacked_to_int8"
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 98480c2db2..f106b6dddb 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -8,10 +8,14 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .intx.intx_unpacked_tensor import (
+    IntxUnpackedTensor,
+)
 
 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "IntxUnpackedTensor",
 ]
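For context, a minimal usage sketch of the new `version=2` path added to `IntxWeightOnlyConfig` above. It mirrors the new test's setup; `PackingFormat.UNPACKED_TO_INT8` is already the config default, so it is not passed explicitly here.

```python
import torch

from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# int4 weight-only quantization, stored unpacked as int8 (UNPACKED_TO_INT8 is
# the default packing_format), with one scale/zero_point per group of 32 columns.
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    version=2,  # opt into the new IntxUnpackedTensor path
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)

# linear.weight is now an IntxUnpackedTensor; the linear op dequantizes the
# weight and falls back to torch.nn.functional.linear.
y = linear(torch.randn(1, 128, dtype=torch.bfloat16))
```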
diff --git a/torchao/quantization/quantize_/workflows/intx/__init__.py b/torchao/quantization/quantize_/workflows/intx/__init__.py
new file mode 100644
index 0000000000..c0f1f807a5
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/intx/__init__.py
@@ -0,0 +1,5 @@
+from .intx_unpacked_tensor import IntxUnpackedTensor
+
+__all__ = [
+    "IntxUnpackedTensor",
+]
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
new file mode 100644
index 0000000000..0a12236647
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List, Optional, Tuple
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_BIT_WIDTH,
+    _DTYPE_TO_QVALUE_BOUNDS,
+    MappingType,
+    choose_qparams_affine,
+    dequantize_affine,
+    quantize_affine,
+)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+    fill_defaults,
+)
+
+__all__ = [
+    "IntxUnpackedTensor",
+]
+
+aten = torch.ops.aten
+
+_FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]
+
+
+class IntxUnpackedTensor(TorchAOBaseTensor):
+    """
+    intx quantization with unpacked format. Subbyte quantized data is represented as int8.
+    Quantization is represented in a decomposed way.
+    This format is intended for torch.export use cases.
+
+    Tensor Attributes:
+        int_data: int data for quantization.
+            dtype is int8
+            Shape is the same as original Tensor: (n, k) for 2D tensor
+        scale: block scales for quantization
+            dtype is the same as the original Tensor dtype.
+            Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
+        zero_point: block zero points for quantization
+            dtype is the same as the original Tensor dtype or int8
+            Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
+
+    Non-Tensor Attributes:
+        bit_width: the bit width for quantization (can be 1 to 8)
+        block_size: the block size for quantization, representing the granularity; for example, groupwise quantization has block_size (1, group_size)
+    """
+
+    tensor_data_names = ["int_data", "scale", "zero_point"]
+    tensor_attribute_names = ["bit_width", "block_size"]
+
+    def __new__(cls, int_data, scale, zero_point, bit_width, block_size=None):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data,
+        scale,
+        zero_point,
+        bit_width,
+        block_size: Optional[Tuple[int]] = None,
+    ):
+        # Check plain data and infer block_size from shapes
+        if block_size is None:
+            assert scale.ndim == int_data.ndim
+            assert zero_point.ndim == int_data.ndim
+            block_size = []
+            for i in range(int_data.ndim):
+                assert scale.shape[i] == zero_point.shape[i]
+                n_blocks = scale.shape[i]
+                assert int_data.shape[i] % n_blocks == 0
+                block_size.append(int_data.shape[i] // n_blocks)
+            block_size = tuple(block_size)
+        else:
+            assert len(block_size) == int_data.ndim
+            n_blocks = []
+            for i in range(len(block_size)):
+                assert int_data.shape[i] % block_size[i] == 0
+                n_blocks.append(int_data.shape[i] // block_size[i])
+            scale = scale.reshape(*n_blocks)
+            zero_point = zero_point.reshape(*n_blocks)
+
+        assert block_size is not None
+        assert isinstance(block_size, tuple)
+        assert bit_width >= 1 and bit_width <= 8
+
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+        self.bit_width = bit_width
+        self.block_size = block_size
+
+    def __repr__(self):
+        repr_fields = (
+            self.tensor_data_names
+            + self.tensor_attribute_names
+            + ["shape", "device", "dtype", "requires_grad"]
+        )
+        inner_repr = [f"{attr}={getattr(self, attr)}" for attr in repr_fields]
+        inner_repr = ", ".join(inner_repr)
+        return f"{self.__class__.__name__}({inner_repr})"
+
+    def _quantization_type(self):
+        return f"bit_width={self.bit_width}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}"
+
+    def _has_float_zero_point(self) -> bool:
+        return self.zero_point.dtype in _FLOAT_TYPES
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        dtype = kwargs.pop("dtype")
+        assert dtype in _FLOAT_TYPES
+        return self.__class__(
+            self.int_data.to(device),
+            self.scale.to(device=device, dtype=dtype),
+            self.zero_point.to(device=device, dtype=dtype)
+            if self._has_float_zero_point()
+            else self.zero_point.to(device),
+            self.bit_width,
+            self.block_size,
+        )
+
+    @classmethod
+    def from_hp(
+        cls,
+        float_tensor: torch.Tensor,
+        block_size: Tuple[int],
+        dtype: torch.dtype,
+        *,
+        mapping_type: MappingType = MappingType.SYMMETRIC,
+    ):
+        """
+        Create an IntxUnpackedTensor from a high-precision tensor
+        """
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
+        scale, zero_point = choose_qparams_affine(
+            float_tensor,
+            mapping_type,
+            block_size,
+            target_dtype=torch.int8,
+            quant_min=qmin,
+            quant_max=qmax,
+        )
+        int_data = quantize_affine(
+            float_tensor,
+            block_size,
+            scale,
+            zero_point,
+            output_dtype=torch.int8,
+            quant_min=qmin,
+            quant_max=qmax,
+        )
+        return IntxUnpackedTensor(
+            int_data=int_data,
+            scale=scale,
+            zero_point=zero_point,
+            bit_width=bit_width,
+            block_size=block_size,
+        )
+
+    def get_plain(self):
+        return self.int_data, self.scale, self.zero_point
+
+    def dequantize(self):
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[getattr(torch, f"int{self.bit_width}")]
+        return dequantize_affine(
+            self.int_data,
+            self.block_size,
+            self.scale,
+            self.zero_point,
+            torch.int8,
+            qmin,
+            qmax,
+            output_dtype=self.dtype,
+        )
+
+
+implements = IntxUnpackedTensor.implements
+
+
+@implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    weight_tensor = weight_tensor.dequantize()
+    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+@implements([torch.nn.functional.embedding, aten.embedding.default])
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    indices, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    weight_tensor = weight_tensor.dequantize()
+    return torch.nn.functional.embedding(indices, weight_tensor, **kwargs)
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+
+    # Slicing must be compatible with the block size to make sense on the quantized tensor
+    # In particular, both start and end must be a multiple of block_size[dim]
+    # Otherwise the sliced tensor cannot be represented as an IntxUnpackedTensor
+    # For example, if block_size = 4, we might have:
+    #
+    # int_data: i i i i | i i i i
+    # scale: s s
+    #
+    # If we set start = 2 and end = 8, then the int_data slice is:
+    #
+    # int_data_slice: i i (i i | i i i i)
+    #
+    # But then the block_size for the first two int_data in the slice is 2
+    # and the remaining blocks have size 4. This cannot be represented
+    # with the metadata we store in an IntxUnpackedTensor, which requires uniform blocking
+
+    assert start % self.block_size[dim] == 0, (
+        f"slice args are incompatible with blocking: start={start} must be divisible by block_size[dim]={self.block_size[dim]}"
+    )
+    start_scale = start // self.block_size[dim]
+
+    assert end % self.block_size[dim] == 0, (
+        f"slice args are incompatible with blocking: end={end} must be divisible by block_size[dim]={self.block_size[dim]}"
+    )
+    end_scale = end // self.block_size[dim]
+
+    int_data = aten.slice.Tensor(self.int_data, dim, start, end, step)
+    scale = aten.slice.Tensor(self.scale, dim, start_scale, end_scale, step)
+    zero_point = aten.slice.Tensor(self.zero_point, dim, start_scale, end_scale, step)
+
+    new = self.__class__(
+        int_data,
+        scale,
+        zero_point,
+        self.bit_width,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+IntxUnpackedTensor.__module__ = "torchao.quantization"
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([IntxUnpackedTensor])
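As a rough illustration of the decomposed representation implemented above, here is a small sketch built directly with `IntxUnpackedTensor.from_hp` (the module import path follows this diff's file layout rather than a documented public API, and the blocking constraint on slicing is the one enforced by the `aten.slice.Tensor` handler):

```python
import torch

from torchao.quantization.quantize_.workflows.intx.intx_unpacked_tensor import (
    IntxUnpackedTensor,
)
from torchao.quantization.utils import compute_error

hp = torch.randn(256, 128, dtype=torch.bfloat16)

# Groupwise int4 quantization: one scale/zero_point per (1, 32) block.
t = IntxUnpackedTensor.from_hp(hp, block_size=(1, 32), dtype=torch.int4)
assert t.int_data.shape == (256, 128)  # same shape as hp, stored as int8
assert t.scale.shape == (256, 4)       # 128 // 32 blocks along the last dim
assert t.zero_point.shape == (256, 4)

# Round-trip error should be reasonable for 4-bit groupwise quantization.
print(compute_error(hp, t.dequantize()))

# Slicing is only allowed at block boundaries; scales/zero_points are sliced to match.
rows = t.narrow(0, 0, 64)  # scale shape becomes (64, 4) since block_size[0] == 1
cols = t.narrow(1, 0, 64)  # scale shape becomes (256, 2); 64 is a multiple of 32
```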