@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Float8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.testing.utils import skip_if_rocm
from torchao.utils import torch_version_at_least

BF16_ACT_CONFIG = Float8WeightOnlyConfig(
Contributor:

I don't think this config makes sense; it's not something we support. From what I understand, this is a bf16 activation with an fp8 sparse weight? We only have kernel support for the fp8 x fp8 + 2:4 sparse matmul; there is no support for mixed input dtypes currently.

Author:

You're right. It seems I should be mirroring test_fp8_cutlass_sparse (from test_sparse_api.py) instead, with the difference being that it uses the new flag/config exposing the tensor subclass added here (see the sketch after this config definition).

    group_size=128,
    packing_format="cutlass_semi_sparse",
)
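For context, here is a minimal sketch (not part of this diff) of what mirroring test_fp8_cutlass_sparse with the new subclass might look like. The packing_format flag on Float8DynamicActivationFloat8WeightConfig is an assumption taken from this discussion, not an existing parameter:

# Hypothetical sketch; the `packing_format` flag below is an assumption, not the final API.
import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity


def check_fp8_semi_sparse_linear(M=32, N=256, K=128):
    dtype, device = torch.bfloat16, "cuda"
    input = torch.randn(M, K, dtype=dtype, device=device)
    linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

    apply_fake_sparsity(linear)  # prune the weight to a 2:4 pattern first
    reference = linear(input)

    # fp8 dynamic activation + fp8 2:4 sparse weight, matching the supported kernel
    config = Float8DynamicActivationFloat8WeightConfig(
        packing_format="cutlass_semi_sparse",  # assumed flag exposing Float8SemiSparseTensor
    )
    quantize_(linear, config)
    quantized = linear(input)

    assert compute_error(reference, quantized) > 20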


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestFloat8SemiSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @skip_if_rocm("ROCm enablement in progress")
    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
            )


instantiate_parametrized_tests(TestFloat8SemiSparseTensor)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -78,6 +78,7 @@
    quantize_affine,
)
from .quantize_.workflows import (
    Float8SemiSparseTensor,
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
@@ -148,6 +149,7 @@
    "Int4TilePackedTo4dTensor",
    "Float8Tensor",
    "Int4OpaqueTensor",
    "Float8SemiSparseTensor",
    # smooth quant - subject to change
    "get_scale",
    "SmoothFakeDynQuantMixin",
2 changes: 2 additions & 0 deletions torchao/quantization/quant_api.py
@@ -1336,6 +1336,7 @@ def _int8_weight_only_quantize_tensor(weight, config):
    if group_size is None:
        group_size = weight.shape[-1]
    block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
    # todo: support fp8 semi-sparse
    new_weight = to_affine_quantized_intx(
        weight,
        mapping_type,
@@ -1584,6 +1585,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
    weight_dtype: torch.dtype = e4m3_dtype
    set_inductor_config: bool = True
    version: int = 2
    # todo: add packing format

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
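As a side note, a minimal sketch (not part of this diff) of what the todo above could become if the packing-format route is kept; the field name and default are assumptions, and the thread on packing_format.py below questions whether this field is needed at all:

# Hypothetical sketch of a packing_format field on Float8WeightOnlyConfig.
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig


@dataclass
class Float8WeightOnlyConfig(AOBaseConfig):
    weight_dtype: torch.dtype = torch.float8_e4m3fn  # e4m3_dtype in the actual config
    set_inductor_config: bool = True
    version: int = 2
    packing_format: str = "plain"  # assumed; e.g. "cutlass_semi_sparse" for 2:4 sparse weights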
1 change: 1 addition & 0 deletions torchao/quantization/quantize_/common/packing_format.py
@@ -32,3 +32,4 @@ class PackingFormat(str, Enum):
    needed for the rest of the system to understand the specific format that's adopted.
    """
    OPAQUE = "opaque"
    # todo: add semi-sparse
Author:

@jerryzh168 It seems we may want to add a packing format for sparse. Wondering if there's a preference between adding it here or in a separate file (similar to int4) for float8?

Contributor:

Do we need a packing format if we have a separate config? It looks like packing format mostly exists to support the different Int4WeightOnlyConfig kernel options (tinygemm, sparse marlin, etc.).

Author:

Good point. I noticed that we seem to replace the dense weight with the quantized semi-sparse tensor in the transform. Would it make more sense to integrate Float8SemiSparseTensor there rather than gating on the packing format, as I proposed previously? (A rough sketch follows below.) cc @jerryzh168
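For illustration, a rough sketch (not part of this diff) of what integrating the subclass in the Float8WeightOnlyConfig transform could look like. The function name mirrors the int8 helper shown in the quant_api.py hunk above and may differ, and the packing_format gate is an assumption from this discussion, not existing code:

# Hypothetical sketch of the weight transform; the `packing_format` field is assumed.
def _float8_weight_only_quantize_tensor(weight, config):
    from torchao.quantization.quantize_.workflows import Float8SemiSparseTensor

    if getattr(config, "packing_format", None) == "cutlass_semi_sparse":
        # Per-row scales: the block spans the full reduction dimension.
        block_size = [1] * (weight.dim() - 1) + [weight.shape[-1]]
        # Replace the dense weight with the quantized 2:4 semi-sparse subclass.
        return Float8SemiSparseTensor.from_hp(weight, block_size)

    # Dense fp8 weight-only path elided in this sketch; the real transform
    # would construct a Float8Tensor here as it does today.
    return weight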

4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,6 @@
from .float8.float8_semi_sparse_tensor import (
    Float8SemiSparseTensor,
)
from .float8.float8_tensor import (
    Float8Tensor,
    QuantizeTensorToFloat8Kwargs,
@@ -38,6 +41,7 @@
    "Int4PlainInt32Tensor",
    "Int4TilePackedTo4dTensor",
    "Float8Tensor",
    "Float8SemiSparseTensor",
    "QuantizeTensorToFloat8Kwargs",
    "Int4OpaqueTensor",
    "Int4ChooseQParamsAlgorithm",
114 changes: 114 additions & 0 deletions torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import torch

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]
aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["sparse", "scale", "meta"]

    def __new__(
        cls,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["device"] = sparse.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        # 2:4 compression halves the inner dim; report the logical (dense) shape
        shape = (sparse.shape[0], 2 * sparse.shape[-1])
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.sparse = sparse
        self.meta = meta
        self.scale = scale

    def _quantization_type(self):
        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.utils import mask_creator

        # Prune to a 2:4 pattern before quantizing
        dense = w * mask_creator(w).bool()

        scale = _choose_scale_float8(
            dense,
            block_size=block_size,
            float8_dtype=torch.float8_e4m3fn,
        )

        w_fp8 = _quantize_affine_float8(
            dense,
            scale=scale,
            float8_dtype=torch.float8_e4m3fn,
        )

        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)

        return cls(
            sparse,
            meta,
            scale,
        )
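As a usage note (a sketch, not part of this diff, assuming the imports above and a CUDA device with the sm90 CUTLASS fp8 sparse kernels available), from_hp takes a weight that it prunes to a 2:4 pattern internally and a block_size spanning the reduction dimension for per-row fp8 scales:

# Sketch only: direct construction; quantize_ with a config is the intended entry point.
w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
w_sp = Float8SemiSparseTensor.from_hp(w, block_size=[1, w.shape[-1]])  # per-row fp8 scales
print(w_sp.shape)  # torch.Size([256, 128]); `sparse` packs K to K/2, `meta` holds the 2:4 metadata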


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.linear.default)
jcaip (Contributor), Oct 16, 2025:

We'll also need to make sure mm and addmm are supported ops as well. The arg order is different from linear, but it should be the same logic.

Author:

Sounds good, I'm on board with that. Mind if I add those ops in a follow-up diff after this lands? (A rough sketch of the mm/addmm handlers follows the end of this file's diff.)

@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    input = input_tensor.qdata
    input_scale = input_tensor.scale
    weight = weight_tensor.sparse
    weight_meta = weight_tensor.meta
    weight_scale = weight_tensor.scale
    out_dtype = input_tensor.dtype

    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
    )

    return out


Float8SemiSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Float8SemiSparseTensor])
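Following up on the mm/addmm discussion above, a rough sketch (not part of this diff) of what those handlers might look like. It reuses `implements` and `aten` from the module above; the shared helper _float8_semi_sparse_mm is hypothetical, it assumes the activation is a quantized subclass exposing qdata/scale as in the linear handler, and transposition handling of the weight operand is elided:

# Rough sketch only: mm/addmm handlers reusing the linear logic with reordered args.
# `_float8_semi_sparse_mm` is a hypothetical helper; weight transposition is elided.
def _float8_semi_sparse_mm(input_tensor, weight_tensor, bias):
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

    return rowwise_scaled_linear_sparse_cutlass_f8f8(
        input_tensor.qdata,
        input_tensor.scale,
        weight_tensor.sparse,
        weight_tensor.meta,
        weight_tensor.scale,
        bias,
        input_tensor.dtype,
    )


@implements(aten.mm.default)
def _(func, types, args, kwargs):
    # mm(input, mat2): no bias argument
    input_tensor, weight_tensor = args[0], args[1]
    return _float8_semi_sparse_mm(input_tensor, weight_tensor, bias=None)


@implements(aten.addmm.default)
def _(func, types, args, kwargs):
    # addmm(bias, input, mat2): bias comes first
    bias, input_tensor, weight_tensor = args[0], args[1], args[2]
    return _float8_semi_sparse_mm(input_tensor, weight_tensor, bias=bias)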