diff --git a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
new file mode 100644
index 0000000000..fe0eeddd55
--- /dev/null
+++ b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+
+from torchao.quantization import (
+    Float8WeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.utils import compute_error
+from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.testing.utils import skip_if_rocm
+from torchao.utils import torch_version_at_least
+
+BF16_ACT_CONFIG = Float8WeightOnlyConfig(
+    group_size=128,
+    packing_format="cutlass_semi_sparse",
+)
+
+
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+class TestFloat8SemiSparseTensor(TestCase):
+    def setUp(self):
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
+
+    @skip_if_rocm("ROCm enablement in progress")
+    @parametrize("config", [BF16_ACT_CONFIG])
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    def test_linear(self, config, sizes):
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+
+        apply_fake_sparsity(linear)
+        original = linear(input)
+        quantize_(linear, config)
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+        compiled_linear = torch.compile(linear)
+        quantized_and_compiled = compiled_linear(input)
+        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+
+    @skip_if_rocm("ROCm enablement in progress")
+    @unittest.skip("Fix later")
+    @parametrize("config", [BF16_ACT_CONFIG])
+    def test_to_device(self, config):
+        for device in self.GPU_DEVICES:
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, config)
+            linear.to(device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, config)
+            linear.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, config)
+            linear.to(device)
+
+    @skip_if_rocm("ROCm enablement in progress")
+    @parametrize("config", [BF16_ACT_CONFIG])
+    def test_module_path(self, config):
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        quantize_(linear.cuda(), config)
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestFloat8SemiSparseTensor)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index aa19aa1890..b44bcb107c 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -78,6 +78,7 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
+    Float8SemiSparseTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -148,6 +149,7 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
+    "Float8SemiSparseTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 139b14cf3f..d9f3026913 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1336,6 +1336,7 @@ def _int8_weight_only_quantize_tensor(weight, config):
     if group_size is None:
         group_size = weight.shape[-1]
     block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
+    # TODO: support fp8 semi-sparse
     new_weight = to_affine_quantized_intx(
         weight,
         mapping_type,
@@ -1584,6 +1585,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
     version: int = 2
+    # TODO: add packing format
 
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index c6546c55f9..9f547289f8 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -32,3 +32,4 @@ class PackingFormat(str, Enum):
     needed for the rest of the system to understand the specific format that's adopted.
     """
     OPAQUE = "opaque"
+    # TODO: add semi-sparse
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 4307637f8e..7166e244a6 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,6 @@
+from .float8.float8_semi_sparse_tensor import (
+    Float8SemiSparseTensor,
+)
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -38,6 +41,7 @@
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
+    "Float8SemiSparseTensor",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
new file mode 100644
index 0000000000..78e58cbf68
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import torch
+
+from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _quantize_affine_float8,
+)
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = ["Float8SemiSparseTensor"]
+aten = torch.ops.aten
+
+
+class Float8SemiSparseTensor(TorchAOBaseTensor):
+    # order must match the constructor so TorchAOBaseTensor can rebuild the
+    # subclass positionally when flattening/unflattening
+    tensor_data_names = ["sparse", "meta", "scale"]
+
+    def __new__(
+        cls,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = sparse.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        # `sparse` holds the 2:4-compressed values (N x K/2), so the logical
+        # weight shape is (N, K)
+        shape = (sparse.shape[0], 2 * sparse.shape[-1])
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        super().__init__()
+        self.sparse = sparse
+        self.meta = meta
+        self.scale = scale
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"
+
+    @classmethod
+    def from_hp(
+        cls,
+        w: torch.Tensor,
+        block_size: List[int],
+    ):
+        from torchao.sparsity.utils import mask_creator
+
+        # prune to a 2:4 pattern (still in dense layout) before quantizing
+        dense = w * mask_creator(w).bool()
+
+        scale = _choose_scale_float8(
+            dense,
+            block_size=block_size,
+            float8_dtype=torch.float8_e4m3fn,
+        )
+
+        w_fp8 = _quantize_affine_float8(
+            dense,
+            scale=scale,
+            float8_dtype=torch.float8_e4m3fn,
+        )
+
+        # compress to the CUTLASS 2:4 layout: packed values + metadata
+        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)
+
+        return cls(
+            sparse,
+            meta,
+            scale,
+        )
+
+
+implements = Float8SemiSparseTensor.implements
+implements_torch_function = Float8SemiSparseTensor.implements_torch_function
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+
+    # the activation is expected to be an fp8 tensor subclass exposing
+    # `qdata`/`scale` (e.g. Float8Tensor)
+    input = input_tensor.qdata
+    input_scale = input_tensor.scale
+    weight = weight_tensor.sparse
+    weight_meta = weight_tensor.meta
+    weight_scale = weight_tensor.scale
+    out_dtype = input_tensor.dtype
+
+    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
+        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
+    )
+
+    return out
+
+
+Float8SemiSparseTensor.__module__ = "torchao.quantization"
+
+# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Float8SemiSparseTensor])
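
Usage sketch: a minimal end-to-end flow for the workflow added above, assuming `packing_format` gets wired into Float8WeightOnlyConfig. The quant_api.py hunk only leaves a TODO for it, so that config field is an assumption here, taken from the test's BF16_ACT_CONFIG.

import torch

from torchao.quantization import Float8WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
# Prune weights to a 2:4 pattern first; Float8SemiSparseTensor.from_hp()
# re-applies the mask before quantizing to fp8.
apply_fake_sparsity(model)
# `packing_format` is the assumed, not-yet-landed config field (see TODO above).
quantize_(model, Float8WeightOnlyConfig(packing_format="cutlass_semi_sparse"))

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
y = model(x)  # dispatches to rowwise_scaled_linear_sparse_cutlass_f8f8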