diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py b/benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py new file mode 100644 index 0000000000..3b38fc2242 --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0 +from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + input_shapes = [ + # (local_batch_size, seq_len, dim) + (1, 8192, 7168), + (2, 8192, 7168), + (4, 8192, 7168), + (8, 8192, 7168), + ] + configs = [] + for shape in input_shapes: + configs.append( + ExperimentConfig( + input_shape=shape, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + block_size = 32 + input_shape = config.input_shape + input_tensor = torch.randn( + *input_shape, + dtype=torch.bfloat16, + device=device, + ) + + e8m0_scales, e4m3_data = to_mx(input_tensor, torch.float8_e4m3fn, block_size) + + # Bench torch dequant + to_dtype_c = torch.compile(to_dtype) + elem_dtype, target_dtype = torch.float8_e4m3fn, torch.bfloat16 + torch_output = to_dtype_c( + e4m3_data, + e8m0_scales, + elem_dtype, + block_size, + target_dtype, + ) + torch_us = benchmark_cuda_function_in_microseconds( + to_dtype_c, + e4m3_data, + e8m0_scales, + elem_dtype, + block_size, + target_dtype, + ) + + # Bench triton kernel + _ = triton_mxfp8_dequant_dim0( + e4m3_data, + e8m0_scales, + target_dtype, + block_size, + ) + triton_us = benchmark_cuda_function_in_microseconds( + triton_mxfp8_dequant_dim0, + e4m3_data, + e8m0_scales, + target_dtype, + block_size, + ) + + # mem bw calculations + bytes_per_input_el = torch.finfo(elem_dtype).bits / 8 + bytes_per_output_el = torch.finfo(target_dtype).bits / 8 + bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + + read_bytes = ( + e4m3_data.numel() * bytes_per_input_el + + e8m0_scales.numel() * bytes_per_scale_el + ) + write_bytes = torch_output.numel() * bytes_per_output_el + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_us / 1e6) + + return ExperimentResult( + torch_us=torch_us, + triton_us=triton_us, + triton_gbps=triton_gbps, + torch_gbps=torch_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "torch_us", + "triton_us", + "torch_gbps", + "triton_gbps", + 
"triton_speedup", + ] + rows = [] + for experiment in experiments: + triton_speedup = round( + experiment.result.torch_us / experiment.result.triton_us, 3 + ) + rows.append( + [ + str(experiment.config.input_shape), + experiment.result.torch_us, + experiment.result.triton_us, + round(experiment.result.torch_gbps, 3), + round(experiment.result.triton_gbps, 3), + f"{triton_speedup}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index abee8b2ff9..240e8eea49 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -37,12 +37,13 @@ pack_uint6, triton_f6_e2m3_to_bf16, triton_f6_e3m2_to_bf16, + triton_mxfp8_dequant_dim0, triton_to_mxfp8_dim0, triton_to_mxfp8_dim1, triton_to_mxfp8_dim1_reference, unpack_uint4, ) -from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx +from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( is_sm_at_least_89, @@ -513,6 +514,28 @@ def test_triton_mxfp8_dim0_zeros(): torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="mxfp8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("M", (256, 2048, 131072)) +@pytest.mark.parametrize("K", (256, 5120, 7168)) +def test_triton_mxfp8_dequant_dim0(M, K): + x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda") + block_size = 32 + x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32) + hp_ref = to_dtype( + x_data, + x_scales, + torch.float8_e4m3fn, + block_size, + torch.bfloat16, + ) + hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size) + torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( "shape", diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 7053225521..31d9e96f4b 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -880,17 +880,19 @@ def _get_mxfp8_quant_autotune_configs(): # sweep over a small set of shapes, it's likely that this # can be improved in the future. 
         results = []
-        for ROW_TILE_SIZE in (64, 128):
-            for COL_TILE_SIZE in (64, 128):
-                for num_warps in (1, 2, 4):
-                    config = triton.Config(
-                        {
-                            "ROW_TILE_SIZE": ROW_TILE_SIZE,
-                            "COL_TILE_SIZE": COL_TILE_SIZE,
-                        },
-                        num_warps=num_warps,
-                    )
-                    results.append(config)
+        for ROW_TILE_SIZE in (128, 256, 512):
+            for COL_TILE_SIZE in (128, 256, 512):
+                for num_warps in (4, 8):
+                    for num_stages in (2, 3):
+                        config = triton.Config(
+                            {
+                                "ROW_TILE_SIZE": ROW_TILE_SIZE,
+                                "COL_TILE_SIZE": COL_TILE_SIZE,
+                            },
+                            num_warps=num_warps,
+                            num_stages=num_stages,
+                        )
+                        results.append(config)
         return results
 
     @triton.autotune(
@@ -1277,6 +1279,105 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1.unsqueeze(-1),
         )
 
+    def triton_mxfp8_dequant_dim0(
+        e4m3_data: torch.Tensor,
+        e8m0_scales: torch.Tensor,
+        out_dtype: torch.dtype,
+        scale_block_size: int = 32,
+    ) -> torch.Tensor:
+        assert scale_block_size == 32, "scale_block_size must be 32 for now"
+        assert out_dtype in (torch.bfloat16, torch.float32), (
+            "out_dtype must be bf16 or fp32"
+        )
+
+        # Flatten the input to a 2D view; the original shape is restored on return.
+        orig_shape = e4m3_data.shape
+        e4m3_data = e4m3_data.reshape(-1, orig_shape[-1])
+        out_buffer = torch.empty_like(e4m3_data, dtype=out_dtype)
+        out_dtype_tl = tl.bfloat16 if out_dtype == torch.bfloat16 else tl.float32
+
+        grid = lambda META: (
+            triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
+            triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
+        )
+        _dequant_mxfp8_kernel[grid](
+            e4m3_data,
+            e8m0_scales.to(torch.uint8),
+            out_buffer,
+            e4m3_data.size(0),
+            e4m3_data.size(1),
+            e8m0_scales.size(0),
+            e8m0_scales.size(1),
+            out_dtype=out_dtype_tl,
+            SCALE_BLOCK_SIZE=scale_block_size,
+        )
+        return out_buffer.reshape(orig_shape)
+
+    @triton.autotune(
+        configs=_get_mxfp8_quant_autotune_configs(),
+        key=["input_num_cols", "SCALE_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def _dequant_mxfp8_kernel(
+        e4m3_data,
+        e8m0_scales,
+        out_buffer,
+        input_num_rows,
+        input_num_cols,
+        scale_num_rows,
+        scale_num_cols,
+        out_dtype: tl.constexpr,
+        SCALE_BLOCK_SIZE: tl.constexpr,
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+    ):
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+        SCALE_BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // SCALE_BLOCK_SIZE
+
+        # Load block of e4m3 data
+        row_offs = pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)
+        col_offs = pid_col * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)
+        block_offs = row_offs[:, None] * input_num_cols + col_offs[None, :]
+        mask = (row_offs[:, None] < input_num_rows) & (
+            col_offs[None, :] < input_num_cols
+        )
+        e4m3_data_block = tl.load(e4m3_data + block_offs, mask=mask)
+
+        # Load block of e8m0 scales
+        scale_col_offs = pid_col * SCALE_BLOCKS_PER_COL_TILE + tl.arange(
+            0, SCALE_BLOCKS_PER_COL_TILE
+        )
+        scale_block_offs = row_offs[:, None] * scale_num_cols + scale_col_offs[None, :]
+        scale_mask = (row_offs[:, None] < scale_num_rows) & (
+            scale_col_offs[None, :] < scale_num_cols
+        )
+        e8m0_scale_block = tl.load(e8m0_scales + scale_block_offs, mask=scale_mask)
+
+        # Dequantize: one e8m0 scale applies to each SCALE_BLOCK_SIZE-element group
+        e4m3_data_block_r = e4m3_data_block.reshape(
+            ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE
+        )
+        e8m0_scale_block_r = e8m0_scale_block.reshape(
+            ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, 1
+        )
+        fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r)
+        data_hp = e4m3_data_block_r.to(tl.float32) * fp32_scale
+
+        # Write to output buffer
+        out_buffer_block = data_hp.to(out_dtype)
+        out_buffer_block = out_buffer_block.reshape(ROW_TILE_SIZE, COL_TILE_SIZE)
+        tl.store(out_buffer + block_offs, out_buffer_block, mask=mask)
+
+    @triton.jit
+    def _e8m0_to_fp32(scale_e8m0):
+        e8m0_exponent_bias = 127
+        e8m0_nan_val = 255
+        s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
+        s_fp = tl.exp2(s_offset.to(tl.float32))
+        s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
+        return s_fp.to(tl.float32)
+
     @triton.jit
     def triton_scale_swizzle(
         scale_ptr,
@@ -1641,6 +1742,14 @@ def triton_quantize_nvfp4(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
 
+    def triton_mxfp8_dequant_dim0(
+        e4m3_data: torch.Tensor,
+        e8m0_scales: torch.Tensor,
+        out_dtype: torch.dtype,
+        scale_block_size: int = 32,
+    ) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
 mxfp8_cuda_extension_available = False
 
 if is_sm_at_least_100():
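
For reviewers, a minimal round-trip sketch of the new dequant kernel. This is illustrative only, not part of the patch; it assumes torch 2.8+, triton, and an SM100+ GPU, and mirrors how the benchmark above pairs to_mx with triton_mxfp8_dequant_dim0:

# Illustrative round-trip sketch (assumptions: torch 2.8+, triton, CUDA capability 10.0+).
import torch

from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
x = torch.randn(8192, 7168, dtype=torch.bfloat16, device="cuda")

# Quantize along dim0; to_mx returns (e8m0 scales, e4m3 data), as in bench_dequantize.py.
e8m0_scales, e4m3_data = to_mx(x, torch.float8_e4m3fn, block_size)

# Dequantize back to bf16 with the triton kernel added in this PR.
x_dq = triton_mxfp8_dequant_dim0(e4m3_data, e8m0_scales, torch.bfloat16, block_size)
assert x_dq.shape == x.shape and x_dq.dtype == torch.bfloat16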