From 9b2cfe86ebe35dbd607f08d9e0f4af53a3d76bd8 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Wed, 15 Oct 2025 12:41:38 -0700
Subject: [PATCH 1/2] Add CutlassSemiSparseFp8Tensor

Summary: Move the float8 CUTLASS semi-sparse layout into its own tensor
subclass, extracted from:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Differential Revision: D84467190
---
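Note: the wrapper reports the logical (dense) shape even though storage is
compressed. With 2:4 semi-structured sparsity the packed buffer keeps half of
each row, so __new__ recovers the logical width as 2 * sparse.shape[-1]. A
minimal sketch of that shape arithmetic, using a dummy stand-in rather than
the real CUTLASS packing:

    import torch

    N, K = 128, 256
    dense = torch.randn(N, K)
    # 2:4 sparsity keeps 2 of every 4 values, i.e. K // 2 values per row
    packed = dense[:, : K // 2].to(torch.float8_e4m3fn)  # stand-in for packed values
    logical_shape = (packed.shape[0], 2 * packed.shape[-1])
    assert logical_shape == (N, K)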
 .../float8/cutlass_semi_sparse_fp8_tensor.py  | 56 ++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py

diff --git a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py
new file mode 100644
index 0000000000..cc3118d1fb
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = ["CutlassSemiSparseFp8Tensor"]
+aten = torch.ops.aten
+
+
+class CutlassSemiSparseFp8Tensor(TorchAOBaseTensor):
+    tensor_data_names = ["sparse", "meta", "scale"]
+
+    def __new__(
+        cls,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = sparse.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        shape = (sparse.shape[0], 2 * sparse.shape[-1])
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        super().__init__()
+        self.sparse = sparse
+        self.meta = meta
+        self.scale = scale
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"
+
+    @classmethod
+    def from_hp(
+        cls,
+    ):
+        raise NotImplementedError("CutlassSemiSparseFp8Tensor.from_hp is not implemented yet")
+
+
+implements = CutlassSemiSparseFp8Tensor.implements
+implements_torch_function = CutlassSemiSparseFp8Tensor.implements_torch_function
+
+CutlassSemiSparseFp8Tensor.__module__ = "torchao.quantization"
+
+# Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([CutlassSemiSparseFp8Tensor])

From fc80e43499ce659405fcb49c71f9aeb95996b5d7 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Wed, 15 Oct 2025 15:18:44 -0700
Subject: [PATCH 2/2] Implement packing and linear

Signed-off-by: Benji Beck
---
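Note: apply_fake_sparsity() prunes each weight to the 2:4 pattern before
quantization, and from_hp() re-derives the same mask via
torchao.sparsity.utils.mask_creator. A naive reimplementation of the pruning
rule, shown only to document the expected pattern (assumes mask_creator keeps
the 2 largest-magnitude values per contiguous group of 4):

    import torch

    def naive_2to4_mask(w: torch.Tensor) -> torch.Tensor:
        # keep the 2 largest-magnitude values in every contiguous group of 4
        groups = w.abs().reshape(-1, 4)
        keep = groups.topk(2, dim=-1).indices
        mask = torch.zeros_like(groups, dtype=torch.bool)
        mask.scatter_(-1, keep, True)
        return mask.reshape(w.shape)

    w = torch.randn(8, 16)
    assert naive_2to4_mask(w).reshape(-1, 4).sum(-1).eq(2).all()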
@@ "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", + "Float8SemiSparseTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 139b14cf3f..d9f3026913 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1336,6 +1336,7 @@ def _int8_weight_only_quantize_tensor(weight, config): if group_size is None: group_size = weight.shape[-1] block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + # todo: support fp8 semi-sparse new_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1584,6 +1585,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype set_inductor_config: bool = True version: int = 2 + # todo: add packing format def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index c6546c55f9..9f547289f8 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -32,3 +32,4 @@ class PackingFormat(str, Enum): needed for the rest of the system to understand the specific format that's adopted. """ OPAQUE = "opaque" + # todo: add semi-sparse diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4307637f8e..7166e244a6 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,3 +1,6 @@ +from .float8.float8_semi_sparse_tensor import ( + Float8SemiSparseTensor, +) from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -38,6 +41,7 @@ "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", "Float8Tensor", + "Float8SemiSparseTensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", "Int4ChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py deleted file mode 100644 index cc3118d1fb..0000000000 --- a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import torch -from torchao.utils import TorchAOBaseTensor - -__all__ = ["CutlassSemiSparseFp8Tensor"] -aten = torch.ops.aten - -class CutlassSemiSparseFp8Tensor(TorchAOBaseTensor): - tensor_data_names = ["sparse", "scale", "meta"] - - def __new__( - cls, - sparse: torch.Tensor, - meta: torch.Tensor, - scale: torch.Tensor, - ): - kwargs = {} - kwargs["device"] = sparse.device - kwargs["dtype"] = scale.dtype - kwargs["requires_grad"] = False - shape = (sparse.shape[0], 2 * sparse.shape[-1]) - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - - def __init__( - self, - sparse: torch.Tensor, - meta: torch.Tensor, - scale: torch.Tensor, - ): - super().__init__() - self.sparse = sparse - self.meta = meta - self.scale = scale - - def _quantization_type(self): - return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" - - - @classmethod - def from_hp( - ): - raise NotImplementedError("CutlassSemiSparseFp8Tensor.from_hp is not implemented yet") - - -implements = CutlassSemiSparseFp8Tensor.implements -implements_torch_function = CutlassSemiSparseFp8Tensor.implements_torch_function - -CutlassSemiSparseFp8Tensor.__module__ = "torchao.quantization" - -# Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([CutlassSemiSparseFp8Tensor]) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py new file mode 100644 index 0000000000..78e58cbf68 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
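+"""CUTLASS 2:4 semi-sparse float8 weight tensor.
+
+`sparse` stores the kept fp8 values (half of each logical row), `meta` the
+CUTLASS 2:4 metadata produced by packing, and `scale` the scales used to
+dequantize matmul results back to the original high-precision dtype.
+"""
+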
+from typing import List
+
+import torch
+
+from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _quantize_affine_float8,
+)
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = ["Float8SemiSparseTensor"]
+aten = torch.ops.aten
+
+
+class Float8SemiSparseTensor(TorchAOBaseTensor):
+    tensor_data_names = ["sparse", "meta", "scale"]
+
+    def __new__(
+        cls,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = sparse.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        shape = (sparse.shape[0], 2 * sparse.shape[-1])
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        super().__init__()
+        self.sparse = sparse
+        self.meta = meta
+        self.scale = scale
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"
+
+    @classmethod
+    def from_hp(
+        cls,
+        w: torch.Tensor,
+        block_size: List[int],
+    ):
+        from torchao.sparsity.utils import mask_creator
+
+        dense = w * mask_creator(w).bool()
+
+        scale = _choose_scale_float8(
+            dense,
+            block_size=block_size,
+            float8_dtype=torch.float8_e4m3fn,
+        )
+
+        w_fp8 = _quantize_affine_float8(
+            dense,
+            scale=scale,
+            float8_dtype=torch.float8_e4m3fn,
+        )
+
+        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)
+
+        return cls(
+            sparse,
+            meta,
+            scale,
+        )
+
+
+implements = Float8SemiSparseTensor.implements
+implements_torch_function = Float8SemiSparseTensor.implements_torch_function
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+
+    input = input_tensor.qdata
+    input_scale = input_tensor.scale
+    weight = weight_tensor.sparse
+    weight_meta = weight_tensor.meta
+    weight_scale = weight_tensor.scale
+    out_dtype = input_tensor.dtype
+
+    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
+        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
+    )
+
+    return out
+
+
+Float8SemiSparseTensor.__module__ = "torchao.quantization"
+
+# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Float8SemiSparseTensor])
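Note: an end-to-end usage sketch for the new tensor (assumes a Hopper-class
GPU, since the SM9x CUTLASS packing and matmul kernels require it; the
Float8WeightOnlyConfig wiring for this packing format is still a todo above,
so this constructs the weight through from_hp directly):

    import torch
    from torchao.quantization import Float8SemiSparseTensor

    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
    # per-row scales: the block covers the whole input dimension (K = 128)
    w_q = Float8SemiSparseTensor.from_hp(w, block_size=[1, 128])
    print(w_q.shape)  # logical (256, 128); storage holds fp8 values + 2:4 metadata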