diff --git a/.github/workflows/1xH100_tests.yml b/.github/workflows/1xH100_tests.yml index 18f1ff9cd4..b5e312bf5b 100644 --- a/.github/workflows/1xH100_tests.yml +++ b/.github/workflows/1xH100_tests.yml @@ -25,7 +25,7 @@ jobs: include: - name: H100 runs-on: linux.aws.h100 - torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126' + torch-spec: '--pre torch torchvision torchaudio fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu126' gpu-arch-type: "cuda" gpu-arch-version: "12.4" permissions: @@ -33,7 +33,7 @@ jobs: contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: - timeout: 60 + timeout: 90 runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} @@ -46,8 +46,8 @@ jobs: pip install uv pip install ${{ matrix.torch-spec }} uv pip install -r dev-requirements.txt - uv pip install vllm pip install . pytest test/integration --verbose -s pytest test/dtypes/test_affine_quantized_float.py --verbose -s + python test/quantization/quantize_/workflows/float8/test_float8_tensor.py ./test/float8/test_everything_single_gpu.sh diff --git a/.github/workflows/1xL4_tests.yml b/.github/workflows/1xL4_tests.yml index cf4bf22423..39175ed0f9 100644 --- a/.github/workflows/1xL4_tests.yml +++ b/.github/workflows/1xL4_tests.yml @@ -46,8 +46,8 @@ jobs: pip install uv pip install ${{ matrix.torch-spec }} uv pip install -r dev-requirements.txt - uv pip install vllm pip install . pytest test/integration --verbose -s pytest test/dtypes/test_affine_quantized_float.py --verbose -s ./test/float8/test_everything_single_gpu.sh + python test/quantization/quantize_/workflows/float8/test_float8_tensor.py diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 56010d7d1b..1f88bdd65d 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -737,6 +737,7 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode): Verify that float8 quantization + torch.compile results in the expected number of kernels in the GPU trace. """ + torch.compiler.reset() M, K, N = 128, 256, 512 m = torch.nn.Sequential( diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py deleted file mode 100644 index ea869a1c39..0000000000 --- a/test/dtypes/test_fbgemm_fp8.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) - -from torchao.float8.config import e4m3_dtype -from torchao.quantization import ( - FbgemmConfig, - quantize_, -) -from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, - is_sm_at_least_90, -) - - -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") -class TestFbgemmFp8Tensor(TestCase): - def setUp(self): - self.config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - ) - self.bmm_config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - transpose_input=True, - ) - self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] - - def test_linear(self): - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - original = linear(input) - quantize_(linear, self.config) - quantized = linear(input) - self.assertTrue(compute_error(original, quantized) > 20) - - def test_slice(self): - dtype = torch.bfloat16 - device = "cuda" - dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) - dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) - dummy1.weight = torch.nn.Parameter( - dummy.weight.narrow(0, 0, 64), requires_grad=False - ) - dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) - dummy2.weight = torch.nn.Parameter( - dummy.weight.narrow(1, 0, 128), requires_grad=False - ) - - quantize_(dummy, self.config) - weight1 = dummy.weight.narrow(0, 0, 64) - weight2 = dummy.weight.narrow(1, 0, 128) - self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64)) - self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) - self.assertEqual( - weight2.float8_data, dummy.weight.float8_data.narrow(1, 0, 128) - ) - self.assertEqual(weight2.scale, dummy.weight.scale) - - # check for sliced weight, before and after float8 quantization - # does not differ too much - input = torch.randn(2, 256, dtype=dtype, device=device) - res_ref = dummy1(input) - dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 25 - - input = torch.randn(2, 128, dtype=dtype, device=device) - res_ref = dummy2(input) - dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 15 - - def test_slice_and_copy_(self): - l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - l.weight = torch.nn.Parameter( - torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") - ) - quantize_(l, self.config) - param = l.weight - param_data = param.data - param_data = param_data.narrow(0, 0, 512) - assert param.data.float8_data.data_ptr() == param_data.float8_data.data_ptr() - assert param.data.scale.data_ptr() == param_data.scale.data_ptr() - orig_value = param.data.float8_data[0][0].item() - - # dummy_l has random input (shouldn't be 0) - dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - quantize_(dummy_l, self.config) - quantized = dummy_l.weight - quantized = quantized.narrow(0, 0, 512) - - param_data.copy_(quantized) - - # making sure param.data is updated - assert 
param.data.float8_data[0][0] != orig_value - - def test_bmm(self): - class M(torch.nn.Module): - def __init__(self, weight): - super().__init__() - self.weight = weight - - def forward(self, x): - return torch.bmm(x, self.weight) - - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(10, 32, 128, dtype=dtype, device=device) - weight = torch.randn(10, 128, 256, dtype=dtype, device=device) - m = M(weight).eval() - original = m(input) - # we need to transpose the weight first for bmm - m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) - quantized = m(input) - self.assertTrue(compute_error(original, quantized) > 20) - - def test_to_device(self): - for device in self.GPU_DEVICES: - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device=device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - -if __name__ == "__main__": - run_tests() diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py new file mode 100644 index 0000000000..e53f1412c2 --- /dev/null +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -0,0 +1,578 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from contextlib import nullcontext +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.prototype.moe_quant.utils import MoEQuantConfig +from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + PerRow, + PerTensor, + quantize_, +) +from torchao.quantization.quantize_.common import KernelPreference +from torchao.quantization.utils import compute_error +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + _is_fbgemm_genai_gpu_available, + is_sm_at_least_89, + is_sm_at_least_90, +) + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 128 + + +class Experts(nn.Module): + def __init__( + self, + num_local_experts: int, + dim: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + + self.num_local_experts = num_local_experts + self.dim = dim + + self.w1: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + dim, + hidden_dim, + dtype=dtype, + device=device, + ) + ) + + self.w2: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + hidden_dim, + dim, + dtype=dtype, + device=device, + ) + ) + + self.w3: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + dim, + hidden_dim, + dtype=dtype, + device=device, + ) + ) + + def forward( + self, + routed_in_egD: torch.Tensor, # noqa: N803 + ) -> torch.Tensor: + e = self.num_local_experts + D = self.dim + + x_egD = routed_in_egD.view(e, -1, D) + + middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3) + out_egD = torch.bmm(middle_out_egF, self.w2) + out_egD = 
out_egD.view(-1, D) + + return out_egD + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) + self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +# TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") +class TestFloat8Tensor(TestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "kernel_preference", + [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], + ) + # Inputs are (M,..), K, N + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ], + ) + def test_fp8_linear_variants( + self, + dtype: torch.dtype, + mode: str, + compile: bool, + granularity, + kernel_preference: KernelPreference, + sizes: Tuple, + ): + error_message = None + if isinstance(granularity, PerRow): + if mode == "dynamic" and dtype != torch.bfloat16: + error_message = "PerRow quantization only works for bfloat16 precision" + + if mode == "weight-only" and kernel_preference != KernelPreference.AUTO: + return unittest.skip( + "weight only quant only uses AUTO kernel preference right now" + ) + + if kernel_preference == KernelPreference.FBGEMM and ( + (not _is_fbgemm_genai_gpu_available()) or (not is_sm_at_least_90()) + ): + return unittest.skip( + "Requires fbgemm_gpu_genai to run fbgemm kernel preference test" + ) + + error_context = ( + self.assertRaisesRegex(AssertionError, error_message) + if error_message + else nullcontext() + ) + + with error_context: + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + + if mode == "dynamic": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + kernel_preference=kernel_preference, + VERSION=2, + ) + else: + assert mode == "weight-only", f"Unsupported mode: {mode}" + config = Float8WeightOnlyConfig() + + quantize_(quantized_model, config) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + error = compute_error(output_original, output_quantized) + assert compute_error(output_original, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {error}" + ) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @unittest.skipIf( + not is_sm_at_least_90(), + "Failing in SM89 right now: " + "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later", + ) + def 
test_slice(self, granularity): + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + dtype = torch.bfloat16 + device = "cuda" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, config) + weight1 = dummy.weight.clone().narrow(0, 0, 64) + weight2 = dummy.weight.clone().narrow(1, 0, 128) + self.assertEqual( + weight1.qdata, + dummy.weight.qdata.narrow(0, 0, 64), + ) + self.assertEqual( + weight2.qdata, + dummy.weight.qdata.narrow(1, 0, 128), + ) + if isinstance(granularity, PerRow): + self.assertEqual( + weight1.scale, + dummy.weight.scale.narrow(0, 0, 64), + ) + self.assertEqual( + weight2.scale, + dummy.weight.scale, + ) + else: + self.assertEqual( + weight1.scale, + dummy.weight.scale, + ) + self.assertEqual( + weight2.scale, + dummy.weight.scale, + ) + + # check for sliced weight, before and after float8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 25, f"sqnr: {sqnr}") + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 15, f"sqnr: {sqnr}") + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_preserves_aliasing(self, granularity): + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_and_copy_similar_to_vllm(self, granularity): + # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly + # the test is similar to the linked code, but with some hardcoded arguments + # and does not use tensor parallelism + + dtype = torch.bfloat16 + device = "cuda" + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype) + quantize_(l, config) + + # high level, we do a narrow for both param.data and the loaded_weights + # and do inplace copy_ to copy from the loaded_weights into param.data + + # simulate loaded_weight + dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + # making the weight different + dummy_l.weight = torch.nn.Parameter( + dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype), + requires_grad=False, + ) + 
quantize_(dummy_l, config) + + output_dim = 0 + shard_size = 512 + for tp_rank in [0, 1]: + start_idx = tp_rank * shard_size + param = l.weight + param_data = param.data + param_data = param_data.narrow(output_dim, start_idx, shard_size) + orig_value = param_data.qdata[0][0].item() + loaded_weight = dummy_l.weight + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0] + assert orig_value != loaded_weight.qdata[0][0] + param_data.copy_(loaded_weight) + # making sure param.data is updated to loaded_weight + assert param_data.qdata[0][0] == loaded_weight.qdata[0][0] + assert param_data.scale[0] == loaded_weight.scale[0] + + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_bmm(self): + # only support per row quantization + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), VERSION=2 + ) + + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(10, 32, 128, dtype=dtype, device=device) + weight = torch.randn(10, 128, 256, dtype=dtype, device=device) + m = M(weight).eval() + original = m(input) + # we need to transpose the weight first for bmm + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, config, filter_fn=lambda x, fqn: True) + quantized = m(input) + self.assertTrue(compute_error(original, quantized) > 20) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_to_device(self, granularity, sizes): + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + M, N, K = sizes + dtype = torch.bfloat16 + for device in self.GPU_DEVICES: + input_tensor = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device=device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_cat(self, granularity, sizes): + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + dtype = torch.bfloat16 + device = "cuda" + M, N, K = sizes + linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device) + linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device) + input_cat1 = torch.randn(*M, K, dtype=dtype, device=device) + + cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) + + dummy_linear1.weight = torch.nn.Parameter(cat_weight1) + quantize_(dummy_linear1, config) + + quantize_(linear1, config) + quantize_(linear2, config) + + cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + self.assertTrue(cat_qweight1.shape, (2 * N, K)) + self.assertEqual( + dummy_linear1.weight.qdata, + cat_qweight1.qdata, + ) + self.assertEqual( + dummy_linear1.weight.scale, + 
cat_qweight1.scale, + ) + + # making sure cat_qweight1 can be used for inference + dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False) + dummy_linear1(input_cat1) + + # align the scale before concatenation + linear2.weight.scale = linear1.weight.scale + cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + self.assertTrue(cat_qweight2.shape, (N, 2 * K)) + ref_data = torch.cat( + [ + linear1.weight.qdata, + linear2.weight.qdata, + ], + dim=1, + ) + ref_scale = linear1.weight.scale + self.assertEqual(cat_qweight2.qdata, ref_data) + self.assertEqual(cat_qweight2.scale, ref_scale) + + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_moe_weight_reshape_ops(self): + """This is testing the op call sequence in saving and loading quantization + checkpoints in llama-models for llama4 + (https://github.com/meta-llama/llama-models/tree/main/models/llama4) + """ + # only per row quantization is supported for bmm + granularity = PerRow() + dtype = torch.bfloat16 + device = "cuda" + + bmm_config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, VERSION=2 + ) + moe_config = MoEQuantConfig(bmm_config) + + batch_size = 4 + num_experts = 2 + input_dim = 64 + dim = 128 + hidden_dim = 256 + + moe1 = Experts(num_experts, dim, hidden_dim, dtype, device) + moe2 = Experts(num_experts, dim, hidden_dim, dtype, device) + moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device) + input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device) + + moes = [moe1, moe2] + + for moe in moes: + moe(input) + + def filter_fn(module, fqn): + return isinstance(module, Experts) + + # need to transpose before quantizing + moe.w1 = torch.nn.Parameter( + moe.w1.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w2 = torch.nn.Parameter( + moe.w2.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w3 = torch.nn.Parameter( + moe.w3.transpose(1, 2).contiguous(), requires_grad=False + ) + + quantize_(moe, moe_config, filter_fn=filter_fn) + + # make sure it runs + before = moe(input) + + # transposing for resharding support since only 2D resharding is supported + new_last_dim = moe.w1.shape[-2] + moe.w1 = torch.nn.Parameter( + moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False + ) + new_last_dim = moe.w2.shape[-2] + moe.w2 = torch.nn.Parameter( + moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False + ) + new_last_dim = moe.w3.shape[-2] + moe.w3 = torch.nn.Parameter( + moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False + ) + + moe.w1 = torch.nn.Parameter( + moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + moe.w2 = torch.nn.Parameter( + moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + moe.w3 = torch.nn.Parameter( + moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + + # transpose again to recover the original weights + moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False) + moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False) + moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False) + + # make sure it runs + after = moe(input) + + self.assertEqual(before, after) + + state_dicts = [moe1.state_dict(), moe2.state_dict()] + # align the scale parameter so they can be concatenated + for key in ["w1", "w2", "w3"]: + weights = [st[key] for st in state_dicts] + for i in range(1, len(weights)): + weights[i].scale = 
weights[0].scale + + def process_key(key: str) -> torch.Tensor: + tensors = [s[key] for s in state_dicts] + # Note: we have a hacky implementation for cat in user codebase + # since it is not implemented correctly before + if key == "w2": + return torch.cat(tensors, dim=-1) + else: + return torch.cat(tensors, dim=-2) + + new_state_dict = {} + for key in ["w1", "w2", "w3"]: + new_state_dict[key] = process_key(key) + + moe_combined.w1 = torch.nn.Parameter( + moe_combined.w1.transpose(1, 2), requires_grad=False + ) + moe_combined.w2 = torch.nn.Parameter( + moe_combined.w2.transpose(1, 2), requires_grad=False + ) + moe_combined.w3 = torch.nn.Parameter( + moe_combined.w3.transpose(1, 2), requires_grad=False + ) + moe_combined.load_state_dict(new_state_dict, assign=True) + # make sure it runs + moe_combined(input) + + +common_utils.instantiate_parametrized_tests(TestFloat8Tensor) + +if __name__ == "__main__": + run_tests() diff --git a/torchao/core/config.py b/torchao/core/config.py index 0985b1af6a..b7e85d6b3d 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -203,6 +203,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: "torchao.prototype.mx_formats", "torchao.dtypes", "torchao.prototype.awq", + "torchao.quantization.quantize_.common", } diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index f87d038430..2a56f9cbcb 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -88,6 +88,7 @@ quantize_affine, ) from .quantize_.workflows import ( + Float8Tensor, Int4PreshuffledTensor, ) from .smoothquant import ( @@ -154,6 +155,7 @@ "FbgemmConfig", # tensor subclasses "Int4PreshuffledTensor", + "Float8Tensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 33439552a0..42088a28bc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -26,7 +26,9 @@ import torch.nn.utils.parametrize as parametrize import torchao -from torchao.core.config import AOBaseConfig +from torchao.core.config import ( + AOBaseConfig, +) from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, @@ -67,8 +69,13 @@ LinearActivationWeightObservedTensor, ) from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size +from torchao.quantization.quantize_.common import ( + KernelPreference, +) from torchao.quantization.quantize_.workflows import ( + Float8Tensor, Int4PreshuffledTensor, + QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, @@ -1482,6 +1489,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor Note: The actual matmul will be computed in original precision of the weight tensor. 
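Editor's note — a minimal usage sketch (not part of the diff) of the VERSION=2 path that the surrounding quant_api.py hunks wire up; it assumes a CUDA device with compute capability >= 8.9 and a torchao build containing this change, and only uses names introduced or already exported in this PR:

```python
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Illustrative sketch: PerRow dynamic-activation float8 requires bfloat16
# (see the precision assertion kept in the quantize-tensor helper below).
model = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda")

# VERSION=2 selects the new Float8Tensor subclass instead of AffineQuantizedTensor.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=2)
quantize_(model, config)

out = model(torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"))
```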
@@ -1489,6 +1497,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype set_inductor_config: bool = True + VERSION: int = 1 # for BC @@ -1496,16 +1505,23 @@ class Float8WeightOnlyConfig(AOBaseConfig): def _float8_weight_only_quant_tensor(weight, config): - from torchao.dtypes import to_affine_quantized_floatx + if config.VERSION == 1: + from torchao.dtypes import to_affine_quantized_floatx - block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) - new_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=config.weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) + new_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=config.weight_dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + else: + assert config.VERSION == 2, f"Unexpected version: {config.VERSION}" + weight_dtype = config.weight_dtype + new_weight = Float8Tensor.to_float8( + weight, float8_dtype=weight_dtype, granularity=PerRow() + ) return new_weight @@ -1603,13 +1619,17 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. - granularity: + granularity (Optional[Union[FP8Granularity, List[FP8Granularity]]]): The granularity for quantization. Can be either a single granularity (applied to both activations and weights) or a tuple of two granularities (one for activations, one for weights). If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And only PerTensor and PerRow are supported. mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + activation_value_lb (Optional[float]): the lower bound for activation value for calculating scale + activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale + kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by defalut (KernelPreference.AUTO) it will be chosen for user based on hardware or other information, this only needs to be set in weight set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. 
+ VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor """ @@ -1617,7 +1637,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None + activation_value_lb: Optional[float] = None + activation_value_ub: Optional[float] = None + kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True + VERSION: int = 1 def __post_init__(self): if self.mm_config is None: @@ -1638,6 +1662,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config + activation_value_lb = config.activation_value_lb + activation_value_ub = config.activation_value_ub + kernel_preference = config.kernel_preference # Ensure works on device _check_hardware_support(granularity) @@ -1652,26 +1679,45 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): "PerRow quantization only works for bfloat16 precision input weight" ) - block_size = get_block_size(weight.shape[-2:], weight_granularity) - if weight.dim() == 3: - block_size = tuple([1] + list(block_size)) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if config.VERSION == 1: + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + else: + assert config.VERSION == 2, f"Unexpected version: {config.VERSION}" + act_quant_kwargs = QuantizeTensorToFloat8Kwargs( + activation_dtype, + activation_granularity, + hp_value_lb=activation_value_lb, + hp_value_ub=activation_value_ub, + ) + + quantized_weight = Float8Tensor.to_float8( + weight, + float8_dtype=weight_dtype, + granularity=weight_granularity, + mm_config=mm_config, + kernel_preference=kernel_preference, + act_quant_kwargs=act_quant_kwargs, + ) - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) return quantized_weight @@ -1760,13 +1806,9 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): granularity: Optional[ Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] ] = None - mm_config: Optional[Float8MMConfig] = None + mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True) set_inductor_config: bool = True - def __post_init__(self): - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) - # for bc float8_static_activation_float8_weight = 
Float8StaticActivationFloat8WeightConfig @@ -2070,7 +2112,7 @@ class FbgemmConfig(AOBaseConfig): weight_dtype: torch.dtype output_dtype: torch.dtype block_size: Optional[List[int]] = None - activation_scale_ub: Optional[float] = None + activation_scale_ub: float = 1200.0 preshuffle: bool = False diff --git a/torchao/quantization/quantize_/common/__init__.py b/torchao/quantization/quantize_/common/__init__.py new file mode 100644 index 0000000000..b6b0102d45 --- /dev/null +++ b/torchao/quantization/quantize_/common/__init__.py @@ -0,0 +1,11 @@ +from .kernel_preference import KernelPreference +from .quantize_tensor_kwargs import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) + +__all__ = [ + "QuantizeTensorKwargs", + "KernelPreference", + "_choose_quant_func_and_quantize_tensor", +] diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py new file mode 100644 index 0000000000..5430463543 --- /dev/null +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +import torch + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class KernelPreference(str, Enum): + """Enum for specifying the groups of kernels that's used for quantization, matrix multiplication + or other compute ops for quantized tensor + + Examples of how options affects the selected kernels can be found in tensor subclass implementations under torchao/quantization/quantize_/workflows + """ + + """Use the most efficient quantize and mm kernels chosen for user based on hardware and library availabilities and versions etc. + """ + AUTO = "auto" + + """Use torch native quantize and quantized mm kernels + """ + TORCH = "torch" + + """Use fbgemm quantize and quantized mm kernels, requires fbgemm_gpu_genai library + """ + FBGEMM = "fbgemm" + + +if TORCH_VERSION_AT_LEAST_2_5: + torch.serialization.add_safe_globals([KernelPreference]) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py new file mode 100644 index 0000000000..443ddea00e --- /dev/null +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from typing import ClassVar + +import torch + +__all__ = [ + "QuantizeTensorKwargs", + "_choose_quant_func_and_quantize_tensor", +] + + +class QuantizeTensorKwargs(abc.ABC): + """Base class for keyword argument container for quantized tensor creation. This is needed to support storing activation construction arguments on the weight tensor while supporting multiple types of activation quantization. + + e.g. + + class Float8Tensor(...) + @classmethod + def to_float8(cls, tensor, quant_kwargs: QuantizeTensorKwargs) + ... 
+ """ + + # Base Version of a config + VERSION: ClassVar[int] = 1 + + +def _choose_quant_func_and_quantize_tensor( + tensor: torch.Tensor, quant_kwargs: QuantizeTensorKwargs +) -> torch.Tensor: + """Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs + quantizes tensor to the derived dtype chosen in (1) + This is needed to support flexible quantization of activation and weights to various derived dtypes. + """ + from torchao.quantization.quantize_.workflows import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, + ) + + if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): + return Float8Tensor.to_float8( + tensor, + quant_kwargs.float8_dtype, + quant_kwargs.granularity, + quant_kwargs.mm_config, + quant_kwargs.hp_value_lb, + quant_kwargs.hp_value_ub, + quant_kwargs.kernel_preference, + ) + + raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 40548e0e0e..2313d2695d 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,7 +1,13 @@ +from .float8.float8_tensor import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, +) from .int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) __all__ = [ "Int4PreshuffledTensor", + "Float8Tensor", + "QuantizeTensorToFloat8Kwargs", ] diff --git a/torchao/quantization/quantize_/workflows/float8/__init__.py b/torchao/quantization/quantize_/workflows/float8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py new file mode 100644 index 0000000000..611c476b76 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -0,0 +1,613 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from dataclasses import dataclass +from typing import List, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.utils import get_out_shape +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _is_rowwise_scaled, + _is_tensorwise_scaled, + _slice_scale_for_dimension, + addmm_float8_unwrapped_inference, + preprocess_data, + preprocess_scale, +) +from torchao.quantization.granularity import PerRow +from torchao.quantization.observer import get_block_size +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _dequantize_affine_float8, + _quantize_affine_float8, +) +from torchao.quantization.quantize_.common import ( + KernelPreference, + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, + _is_fbgemm_genai_gpu_available, + fill_defaults, + is_sm_at_least_90, +) + +__all__ = [ + "Float8Tensor", + "QuantizeTensorToFloat8Kwargs", +] + +aten = torch.ops.aten + + +@dataclass +class QuantizeTensorToFloat8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating float8 tensor (either activation or weight) + + Args: + dtype (torch.dtype): the dtype for float8 Tensor + granularity (FP8Granularity): the granularity for the Tensor, currently either PerRow() or PerTensor() + mm_config (Float8MMConfig): Configuration for the scaled_mm in the forward and backward pass. + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale + kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by defalut (None) it will be chosen for user based on hardware or other information + """ + + float8_dtype: torch.dtype = torch.float8_e4m3fn + granularity: FP8Granularity = PerRow() + mm_config: Optional[Float8MMConfig] = None + hp_value_lb: Optional[float] = None + hp_value_ub: Optional[float] = None + kernel_preference: KernelPreference = KernelPreference.AUTO + + +class Float8Tensor(TorchAOBaseTensor): + """ + Float8 Quantized (weight) Tensor, with float8 dynamic quantization for activation or bfloat16 activation. + + TODO: needs padding for cutlass kernels + + Tensor Attributes: + qdata: float8 raw data + scale: the scale for float8 Tensor + + Non-Tensor Attributes: + block_size (List[int]): the block size for float8 quantization, meaning the shape of the elements + sharing the same set of quantization parameters (scale), have the same rank as qdata or + is an empty list (representing per tensor quantization) + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale + act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Float8Tensor.to_float8 + kernel_preference (KernelPreference): the preference for quantize, mm etc. kernel to use, + by default, this will be chosen for user based on hardware, library availabilities etc. 
+ dtype: Original Tensor dtype + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = [ + "block_size", + "mm_config", + "hp_value_lb", + "hp_value_ub", + "act_quant_kwargs", + "kernel_preference", + "dtype", + ] + + def __new__( + cls, + qdata, + scale, + block_size, + mm_config, + hp_value_lb, + hp_value_ub, + act_quant_kwargs, + kernel_preference, + dtype, + ): + shape = qdata.shape + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: Optional[List[int]] = None, + mm_config: Optional[Float8MMConfig] = None, + hp_value_lb: Optional[float] = None, + hp_value_ub: Optional[float] = None, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + dtype: Optional[torch.dtype] = None, + ): + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.mm_config = mm_config + self.hp_value_lb = hp_value_lb + self.hp_value_ub = hp_value_ub + self.act_quant_kwargs = act_quant_kwargs + self.kernel_preference = kernel_preference + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " + f"{self.block_size=}, {self.mm_config=}, " + f"{self.shape=}, {self.device=}, {self.dtype=})" + ) + + def _quantization_type(self): + return f"{self.act_quant_kwargs=}, {self.block_size=}, {self.mm_config=}, {self.scale.shape=}" + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + qdata, scale = self.qdata, self.scale + return _dequantize_affine_float8(qdata, scale, output_dtype) + + @classmethod + def to_float8( + cls, + hp_tensor: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + granularity: FP8Granularity = PerRow(), + mm_config: Optional[Float8MMConfig] = None, + hp_value_lb: Optional[float] = None, + hp_value_ub: Optional[float] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + block_size = get_block_size(hp_tensor.shape, granularity) + block_size = list(block_size) + + # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance + if ( + kernel_preference == KernelPreference.AUTO + and _is_fbgemm_genai_gpu_available() + and ( + tuple(block_size) + == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],) + ) + ): + assert float8_dtype == torch.float8_e4m3fn, ( + f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}" + ) + if hp_value_ub is not None: + maybe_hp_value_ub_tensor = torch.tensor( + hp_value_ub, dtype=torch.float, device=hp_tensor.device + ) + else: + maybe_hp_value_ub_tensor = None + data, scale = torch.ops.triton.quantize_fp8_row( + hp_tensor, scale_ub=maybe_hp_value_ub_tensor + ) + scale_shape = [] + for i in range(hp_tensor.ndim): + scale_shape.append(hp_tensor.shape[i] // block_size[i]) + scale = scale.reshape(*scale_shape) + else: + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=float8_dtype, + block_size=block_size, + hp_value_lb=hp_value_lb, + hp_value_ub=hp_value_ub, + ) + data = _quantize_affine_float8(hp_tensor, scale, float8_dtype) + + hp_dtype = hp_tensor.dtype + return Float8Tensor( + data, + scale, + 
block_size=block_size, + mm_config=mm_config, + hp_value_lb=hp_value_lb, + hp_value_ub=hp_value_ub, + act_quant_kwargs=act_quant_kwargs, + kernel_preference=kernel_preference, + dtype=hp_dtype, + ) + + +implements = Float8Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert isinstance(weight_tensor, Float8Tensor), ( + f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" + ) + + act_quant_kwargs = weight_tensor.act_quant_kwargs + # quantizing activation, if `act_quant_kwargs` is specified + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + if isinstance(input_tensor, Float8Tensor): + kernel_choice = None + + if weight_tensor.kernel_preference == KernelPreference.AUTO: + kernel_choice = "torch" + if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(): + kernel_choice = "fbgemm" + elif weight_tensor.kernel_preference == KernelPreference.FBGEMM: + kernel_choice = "fbgemm" + else: + assert weight_tensor.kernel_preference == KernelPreference.TORCH, ( + f"{weight_tensor.kernel_preference=} not handled" + ) + kernel_choice = "torch" + + if kernel_choice == "fbgemm": + assert _is_fbgemm_genai_gpu_available(), ( + "Expected fbgemm_gpu_genai package to be installed" + ) + assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai" + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) + wq = weight_tensor.qdata.contiguous() + x_scale = input_tensor.scale + w_scale = weight_tensor.scale + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), ( + "Input tensor must be rowwise block size" + ) + res = torch.ops.fbgemm.f8f8bf16_rowwise( + xq, + wq, + x_scale, + w_scale, + ).reshape(out_shape) + else: + assert _is_tensorwise_scaled(weight_tensor) + assert _is_tensorwise_scaled(input_tensor) + res = torch.ops.fbgemm.f8f8bf16( + xq, + wq, + x_scale * w_scale, + ).reshape(out_shape) + if bias is not None: + res = res + bias + return res + else: + assert kernel_choice == "torch" + scaled_mm_config = weight_tensor.mm_config + assert scaled_mm_config is not None + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + # Extract tensor data and scales + inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) + w_data = weight_tensor.qdata + input_scale = input_tensor.scale + w_scale = weight_tensor.scale + + # Handle rowwise scaling + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), ( + "Input tensor must be rowwise block size" + ) + w_scale = w_scale.transpose(-1, -2) + + input_scale = preprocess_scale(input_scale, input_tensor.shape) + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + else: + assert not isinstance(input_tensor, TorchAOBaseTensor), ( + "Expecting input_tensor to be unquantized" + ) + # when input is not `Float8Tensor`, we expect that it is not quantized + # so this is float8 weight only quantization + return torch.nn.functional.linear( + input_tensor, 
weight_tensor.dequantize(), bias + ) + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + assert isinstance(weight_tensor, Float8Tensor), ( + f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" + ) + + kernel_preference = weight_tensor.kernel_preference + assert kernel_preference != KernelPreference.TORCH, "bmm is not supported for TORCH" + assert _is_fbgemm_genai_gpu_available(), ( + "bmm is not supported when fbgemm_gpu_genai is not installed" + ) + + orig_act_size = input_tensor.size() + act_quant_kwargs = weight_tensor.act_quant_kwargs + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + if isinstance(input_tensor, Float8Tensor): + a_data = input_tensor.qdata + a_scale = input_tensor.scale + + b_data = weight_tensor.qdata + b_scale = weight_tensor.scale.squeeze(-1) + assert b_data.is_contiguous(), "weight for bmm must be contiguous" + + assert ( + all(x == 1 for x in weight_tensor.block_size[:-1]) + and weight_tensor.block_size[-1] == weight_tensor.shape[-1] + ), "bmm only works for per row weight quantization" + assert ( + all(x == 1 for x in input_tensor.block_size[:-1]) + and input_tensor.block_size[-1] == input_tensor.shape[-1] + ), "bmm only works for per row activation quantization" + + orig_out_features = b_data.shape[-2] + + res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( + a_data, + b_data, + a_scale, + b_scale, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + else: + raise NotImplementedError( + "bmm only support float8 dynamic activation + float8 weight" + ) + + return res + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Only supports slicing for dim == 1 and dim == 2 + original tensor shape has dimension (N, K) + qdata has dimension (N, K) + scale (per row quantization) has dimension: (N,) + + since qdata has the same dimension as original tensor, we can directly slice that + for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1 + + Note that we need to call slice on the qdata and scale directly because slice + is an operation that need to preserve aliasing + """ + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + + assert self.qdata.ndim == 2, ( + f"Expected packed weight to have dim 2, got {self.qdata.dim}" + ) + + # Always slice the qdata + sliced_data = aten.slice.Tensor(self.qdata, dim, start, end, step) + + if self.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = self.scale + else: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) + + # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i] + block_size = self.block_size.copy() + for i in range(len(self.block_size)): + block_size[i] = min(block_size[i], sliced_data.shape[i]) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8Tensor( + sliced_data, + sliced_scale, + block_size, + self.mm_config, + self.hp_value_lb, + self.hp_value_ub, + self.act_quant_kwargs, + self.kernel_preference, + dtype=self.dtype, + ), + ) + + 
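Editor's note — an illustrative sketch (not part of the diff) of the aliasing guarantee the `aten.slice.Tensor` override above provides, mirroring `test_slice_preserves_aliasing` in the new test file; it assumes CUDA with SM89+ and a build containing this change:

```python
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Illustrative sketch: narrow a quantized weight and check that qdata/scale
# are views over the same storage, the property vLLM-style weight loading relies on.
linear = torch.nn.Linear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=2))

full = linear.weight.data           # Float8Tensor
shard = full.narrow(0, 0, 512)      # dispatches to the slice override above

assert shard.qdata.data_ptr() == full.qdata.data_ptr()
assert shard.scale.data_ptr() == full.scale.data_ptr()
```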
+@implements(aten.cat.default) +def _(func, types, args, kwargs): + """Concatenate multiple float8 quantized tensors + (scale and qdata has the same rank) + If the concatenation dimension is not the same as block_size, then we can just concatenate the + qdata and scale directly + If the concatention dimension is the same as block_size, theoretically we should either + (1) check that scales from all tensors are equal and use the first scale + (2) dequantize and requantize + but for now we just use the first scale directly, which might have slight implication on accuaracy + we can improve upon this a bit later + """ + + tensors, dim = fill_defaults(args, 2, [[], 0]) + tensor_0 = tensors[0] + dim = dim % tensor_0.ndim + + for i in range(1, len(tensors)): + assert tensor_0.qdata.ndim == tensors[i].qdata.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.block_size == tensors[i].block_size + assert tensor_0.mm_config == tensors[i].mm_config + assert tensor_0.hp_value_lb == tensors[i].hp_value_lb + assert tensor_0.hp_value_ub == tensors[i].hp_value_ub + assert tensor_0.act_quant_kwargs == tensors[i].act_quant_kwargs + assert tensor_0.kernel_preference == tensors[i].kernel_preference + + qdatas = [t.qdata for t in tensors] + scales = [t.scale for t in tensors] + + cat_qdata = aten.cat.default(qdatas, dim=dim) + if tensor_0.block_size[dim] == 1: + cat_scale = aten.cat.default(scales, dim=dim) + else: + for i in range(1, len(tensors)): + assert torch.equal(tensor_0.scale, tensors[i].scale) + cat_scale = scales[0] + + block_size = [] + for i in range(cat_qdata.ndim): + block_size.append(cat_qdata.shape[i] // cat_scale.shape[i]) + + new = tensor_0.__class__( + cat_qdata, + cat_scale, + block_size, + tensor_0.mm_config, + tensor_0.hp_value_lb, + tensor_0.hp_value_ub, + tensor_0.act_quant_kwargs, + tensor_0.kernel_preference, + tensor_0.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + qdata = self.qdata.transpose(dim0, dim1).contiguous() + scale = self.scale.transpose(dim0, dim1).contiguous() + block_size = self.block_size.copy() + + block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.hp_value_lb, + self.hp_value_ub, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, size = args + original_shape = self.shape + if len(original_shape) == 3 and len(size) == 2: + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + qdata = self.qdata.reshape(*size) + scale = self.scale.reshape(*size) + block_size = self.block_size.copy() + block_size = [block_size[0] * block_size[1], block_size[2]] + elif len(original_shape) == 2 and len(size) == 3: + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + qdata = self.qdata.reshape(*size) + block_size = self.block_size.copy() + block_size = [1, block_size[0], block_size[1]] + scale_shape = [] + for i in range(3): + scale_shape.append(qdata.shape[i] // block_size[i]) + scale = self.scale.reshape(*scale_shape) + elif len(original_shape) == len(size): + assert all(x == 
y or y == -1 for x, y in zip(original_shape, size)), ( + f"Only support viewing with match dimensions or -1, got: {original_shape}, {size}" + ) + qdata = self.qdata.reshape(*size) + scale_shape = [] + for i in range(3): + scale_shape.append(qdata.shape[i] // self.block_size[i]) + scale = self.scale.reshape(*scale_shape) + block_size = self.block_size + else: + assert len(original_shape) == 2 and len(size) == 3, ( + f"Only support reshaping from 2D to 3D or from 3D to 2D, requested: reshaping from {original_shape} to {size}" + ) + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.hp_value_lb, + self.hp_value_ub, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.squeeze.dim) +def _(func, types, args, kwargs): + self, dim = args + assert dim == 0, f"Only dim == 0 is supported, got: {dim}" + qdata = self.qdata.squeeze(dim=dim) + scale = self.scale.squeeze(dim=dim) + block_size = [] + for i in range(len(qdata.shape)): + block_size.append(qdata.shape[i] // scale.shape[i]) + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.hp_value_lb, + self.hp_value_ub, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +Float8Tensor.__module__ = "torchao.quantization" + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs])
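Editor's note — a small sketch (not part of the diff) of the checkpointing flow that the `add_safe_globals` registration at the end of float8_tensor.py enables; it assumes torch 2.5+, a CUDA device with SM89+, and the file path here is illustrative:

```python
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Illustrative sketch: save and reload a float8-quantized module's state dict.
m = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=2))

torch.save(m.state_dict(), "/tmp/float8_linear.pt")

# weights_only=True is permitted because Float8Tensor and
# QuantizeTensorToFloat8Kwargs are registered as safe globals above.
state_dict = torch.load("/tmp/float8_linear.pt", weights_only=True)
assert type(state_dict["weight"]).__name__ == "Float8Tensor"
```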