-
Notifications
You must be signed in to change notification settings - Fork 349
[WIP] Move float8 cutlass sparse layout to Float8SemiSparseTensor #3182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import tempfile | ||
import unittest | ||
|
||
import torch | ||
from torch.testing._internal.common_utils import ( | ||
TestCase, | ||
instantiate_parametrized_tests, | ||
parametrize, | ||
run_tests, | ||
) | ||
|
||
from torchao.quantization import ( | ||
Float8WeightOnlyConfig, | ||
quantize_, | ||
) | ||
from torchao.quantization.utils import compute_error | ||
from torchao.sparsity.sparse_api import apply_fake_sparsity | ||
from torchao.testing.utils import skip_if_rocm | ||
from torchao.utils import torch_version_at_least | ||
|
||
BF16_ACT_CONFIG = Float8WeightOnlyConfig( | ||
group_size=128, | ||
packing_format="cutlass_semi_sparse", | ||
) | ||
|
||
|
||
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
class TestFloat8SemiSparseTensor(TestCase): | ||
def setUp(self): | ||
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] | ||
|
||
@skip_if_rocm("ROCm enablement in progress") | ||
@parametrize("config", [BF16_ACT_CONFIG]) | ||
@parametrize( | ||
"sizes", | ||
[ | ||
((128,), 256, 128), | ||
((32, 128), 512, 128), | ||
((2, 32, 128), 256, 12), | ||
], | ||
) | ||
def test_linear(self, config, sizes): | ||
dtype = torch.bfloat16 | ||
device = "cuda" | ||
|
||
M, N, K = sizes | ||
input = torch.randn(*M, K, dtype=dtype, device=device) | ||
linear = torch.nn.Linear(K, N, dtype=dtype, device=device) | ||
|
||
apply_fake_sparsity(linear) | ||
original = linear(input) | ||
quantize_(linear, config) | ||
quantized = linear(input) | ||
self.assertTrue(compute_error(original, quantized) > 20) | ||
|
||
compiled_linear = torch.compile(linear) | ||
quantized_and_compiled = compiled_linear(input) | ||
self.assertTrue(compute_error(original, quantized_and_compiled) > 20) | ||
|
||
@skip_if_rocm("ROCm enablement in progress") | ||
@unittest.skip("Fix later") | ||
@parametrize("config", [BF16_ACT_CONFIG]) | ||
def test_to_device(self, config): | ||
for device in self.GPU_DEVICES: | ||
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) | ||
quantize_(linear, config) | ||
linear.to(device) | ||
|
||
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) | ||
quantize_(linear, config) | ||
linear.to(device=device) | ||
|
||
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) | ||
quantize_(linear, config) | ||
linear.to(device) | ||
|
||
@skip_if_rocm("ROCm enablement in progress") | ||
@parametrize("config", [BF16_ACT_CONFIG]) | ||
def test_module_path(self, config): | ||
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) | ||
quantize_(linear.cuda(), config) | ||
self.assertEqual( | ||
str(type(linear.weight)), | ||
"<class 'torchao.quantization.Float8SemiSparseTensor'>", | ||
) | ||
|
||
with tempfile.NamedTemporaryFile() as f: | ||
torch.save(linear.state_dict(), f) | ||
f.seek(0) | ||
state_dict = torch.load(f) | ||
self.assertEqual( | ||
str(type(state_dict["weight"])), | ||
"<class 'torchao.quantization.Float8SemiSparseTensor'>", | ||
) | ||
|
||
|
||
instantiate_parametrized_tests(TestFloat8SemiSparseTensor) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,3 +32,4 @@ class PackingFormat(str, Enum): | |
needed for the rest of the system to understand the specific format that's adopted. | ||
""" | ||
OPAQUE = "opaque" | ||
# todo: add semi-sparse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jerryzh168 It seems we may want to add a packing format for sparse. Wondering if there's a preference between adding it here or in a separate file (similar to int4) for float8? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need packing format if we have a separate config? It looks like packing format is mostly to support different Int4WeightOnlyConfig kernel options (tinygemm, sparse marlin, etc). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I noticed that we seem to replace the dense weight with quantized semi-sparse in the transform Would it make more sense to integrate Float8SemiSparseTensor here rather than gating with packing-format as I proposed previously? cc @jerryzh168 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import List | ||
|
||
import torch | ||
|
||
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 | ||
from torchao.quantization.quant_primitives import ( | ||
_choose_scale_float8, | ||
_quantize_affine_float8, | ||
) | ||
from torchao.utils import TorchAOBaseTensor | ||
|
||
__all__ = ["Float8SemiSparseTensor"] | ||
aten = torch.ops.aten | ||
|
||
|
||
class Float8SemiSparseTensor(TorchAOBaseTensor): | ||
tensor_data_names = ["sparse", "scale", "meta"] | ||
|
||
def __new__( | ||
cls, | ||
sparse: torch.Tensor, | ||
meta: torch.Tensor, | ||
scale: torch.Tensor, | ||
): | ||
kwargs = {} | ||
kwargs["device"] = sparse.device | ||
kwargs["dtype"] = scale.dtype | ||
kwargs["requires_grad"] = False | ||
shape = (sparse.shape[0], 2 * sparse.shape[-1]) | ||
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
||
def __init__( | ||
self, | ||
sparse: torch.Tensor, | ||
meta: torch.Tensor, | ||
scale: torch.Tensor, | ||
): | ||
super().__init__() | ||
self.sparse = sparse | ||
self.meta = meta | ||
self.scale = scale | ||
|
||
def _quantization_type(self): | ||
return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" | ||
|
||
@classmethod | ||
def from_hp( | ||
cls, | ||
w: torch.Tensor, | ||
block_size: List[int], | ||
): | ||
from torchao.sparsity.utils import mask_creator | ||
|
||
dense = w * mask_creator(w).bool() | ||
|
||
scale = _choose_scale_float8( | ||
dense, | ||
block_size=block_size, | ||
float8_dtype=torch.float8_e4m3fn, | ||
) | ||
|
||
w_fp8 = _quantize_affine_float8( | ||
dense, | ||
scale=scale, | ||
float8_dtype=torch.float8_e4m3fn, | ||
) | ||
|
||
sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8) | ||
|
||
return cls( | ||
sparse, | ||
meta, | ||
scale, | ||
) | ||
|
||
|
||
implements = Float8SemiSparseTensor.implements | ||
implements_torch_function = Float8SemiSparseTensor.implements_torch_function | ||
|
||
|
||
@implements(aten.linear.default) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll also need to make sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, I'm onboard with that. Mind if I add those ops in a follow-up diff after this lands? |
||
@implements_torch_function(torch.nn.functional.linear) | ||
def _(func, types, args, kwargs): | ||
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 | ||
|
||
input_tensor, weight_tensor, bias = ( | ||
args[0], | ||
args[1], | ||
args[2] if len(args) > 2 else None, | ||
) | ||
|
||
input = input_tensor.qdata | ||
input_scale = input_tensor.scale | ||
weight = weight_tensor.sparse | ||
weight_meta = weight_tensor.meta | ||
weight_scale = weight_tensor.scale | ||
out_dtype = input_tensor.dtype | ||
|
||
out = rowwise_scaled_linear_sparse_cutlass_f8f8( | ||
input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype | ||
) | ||
|
||
return out | ||
|
||
|
||
Float8SemiSparseTensor.__module__ = "torchao.quantization" | ||
|
||
# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True` | ||
torch.serialization.add_safe_globals([Float8SemiSparseTensor]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this config makes sense, it's not something we support. From what I understand this is a bf16 a + fp8 sparse weight? We only have kernel support for fp8xfp8 +2:4 sparse matmul, no support for mixed input dtypes currently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, it seems I should be mirroring test_fp8_cutlass_sparse (from test_sparse_api.py) instead
with the difference being using the new flag/config which exposes the tensor subclass being added?