170 changes: 170 additions & 0 deletions benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# This benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]


@dataclass(frozen=True)
class ExperimentResult:
    # time
    torch_us: float
    triton_us: float
    # memory bandwidth
    torch_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    input_shapes = [
        # (local_batch_size, seq_len, dim)
        (1, 8192, 7168),
        (2, 8192, 7168),
        (4, 8192, 7168),
        (8, 8192, 7168),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    block_size = 32
    input_shape = config.input_shape
    input_tensor = torch.randn(
        *input_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    e8m0_scales, e4m3_data = to_mx(input_tensor, torch.float8_e4m3fn, block_size)

    # Bench torch dequant
    to_dtype_c = torch.compile(to_dtype)
    elem_dtype, target_dtype = torch.float8_e4m3fn, torch.bfloat16
    torch_output = to_dtype_c(
        e4m3_data,
        e8m0_scales,
        elem_dtype,
        block_size,
        target_dtype,
    )
    torch_us = benchmark_cuda_function_in_microseconds(
        to_dtype_c,
        e4m3_data,
        e8m0_scales,
        elem_dtype,
        block_size,
        target_dtype,
    )

    # Bench triton kernel
    _ = triton_mxfp8_dequant_dim0(
        e4m3_data,
        e8m0_scales,
        target_dtype,
        block_size,
    )
    triton_us = benchmark_cuda_function_in_microseconds(
        triton_mxfp8_dequant_dim0,
        e4m3_data,
        e8m0_scales,
        target_dtype,
        block_size,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(elem_dtype).bits / 8
    bytes_per_output_el = torch.finfo(target_dtype).bits / 8
    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = (
        e4m3_data.numel() * bytes_per_input_el
        + e8m0_scales.numel() * bytes_per_scale_el
    )
    write_bytes = torch_output.numel() * bytes_per_output_el

    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_us / 1e6)

    return ExperimentResult(
        torch_us=torch_us,
        triton_us=triton_us,
        triton_gbps=triton_gbps,
        torch_gbps=torch_gbps,
    )

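# Back-of-envelope check of the bandwidth model in run_experiment above: for the
# (1, 8192, 7168) shape, read_bytes = 8192 * 7168 * 1 + 8192 * 224 * 1 = 60,555,264
# and write_bytes = 8192 * 7168 * 2 = 117,440,512, i.e. roughly 0.178 GB moved per
# dequantize call.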

def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "torch_us",
        "triton_us",
        "torch_gbps",
        "triton_gbps",
        "triton_speedup",
    ]
    rows = []
    for experiment in experiments:
        triton_speedup = round(
            experiment.result.torch_us / experiment.result.triton_us, 3
        )
        rows.append(
            [
                str(experiment.config.input_shape),
                experiment.result.torch_us,
                experiment.result.triton_us,
                round(experiment.result.torch_gbps, 3),
                round(experiment.result.triton_gbps, 3),
                f"{triton_speedup}x",
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
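
# Example invocation (assuming the script is run from the torchao repo root so that
# the `benchmarks` and `torchao` packages are importable):
#   python benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py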
25 changes: 24 additions & 1 deletion test/prototype/mx_formats/test_kernels.py
@@ -37,12 +37,13 @@
    pack_uint6,
    triton_f6_e2m3_to_bf16,
    triton_f6_e3m2_to_bf16,
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
    triton_to_mxfp8_dim1,
    triton_to_mxfp8_dim1_reference,
    unpack_uint4,
)
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
    is_sm_at_least_89,
@@ -513,6 +514,28 @@ def test_triton_mxfp8_dim0_zeros():
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="mxfp8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("M", (256, 2048, 131072))
@pytest.mark.parametrize("K", (256, 5120, 7168))
def test_triton_mxfp8_dequant_dim0(M, K):
    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
    block_size = 32
    x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
    hp_ref = to_dtype(
        x_data,
        x_scales,
        torch.float8_e4m3fn,
        block_size,
        torch.bfloat16,
    )
    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
    torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)

Contributor (review comment): lgtm, didn't look at the rest too closely

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"shape",
131 changes: 120 additions & 11 deletions torchao/prototype/mx_formats/kernels.py
@@ -880,17 +880,19 @@ def _get_mxfp8_quant_autotune_configs():
    # sweep over a small set of shapes, it's likely that this
    # can be improved in the future.
    results = []
    for ROW_TILE_SIZE in (64, 128):
        for COL_TILE_SIZE in (64, 128):
            for num_warps in (1, 2, 4):
                config = triton.Config(
                    {
                        "ROW_TILE_SIZE": ROW_TILE_SIZE,
                        "COL_TILE_SIZE": COL_TILE_SIZE,
                    },
                    num_warps=num_warps,
                )
                results.append(config)
    for ROW_TILE_SIZE in (128, 256, 512):
        for COL_TILE_SIZE in (128, 256, 512):
            for num_warps in (4, 8):
                for num_stages in (2, 3):
                    config = triton.Config(
                        {
                            "ROW_TILE_SIZE": ROW_TILE_SIZE,
                            "COL_TILE_SIZE": COL_TILE_SIZE,
                        },
                        num_warps=num_warps,
                        num_stages=num_stages,
                    )
                    results.append(config)
    return results

@triton.autotune(
@@ -1277,6 +1279,105 @@ def triton_to_mxfp8_dim1_reference(
        scale_e8m0_dim1.unsqueeze(-1),
    )

def triton_mxfp8_dequant_dim0(
    e4m3_data: torch.Tensor,
    e8m0_scales: torch.Tensor,
    out_dtype: torch.dtype,
    scale_block_size: int = 32,
) -> torch.Tensor:
    assert scale_block_size == 32, "scale_block_size must be 32 for now"
    assert out_dtype in (torch.bfloat16, torch.float32), (
        "out_dtype must be bf16 or fp32"
    )

    # Flatten leading dims; the kernel operates on a 2D view of the input.
    orig_shape = e4m3_data.shape
    e4m3_data = e4m3_data.reshape(-1, orig_shape[-1])
    out_buffer = torch.empty_like(e4m3_data, dtype=out_dtype)
    out_dtype_tl = tl.bfloat16 if out_dtype == torch.bfloat16 else tl.float32

    grid = lambda META: (
        triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
        triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
    )
    _dequant_mxfp8_kernel[grid](
        e4m3_data,
        e8m0_scales.to(torch.uint8),
        out_buffer,
        e4m3_data.size(0),
        e4m3_data.size(1),
        e8m0_scales.size(0),
        e8m0_scales.size(1),
        out_dtype=out_dtype_tl,
        SCALE_BLOCK_SIZE=scale_block_size,
    )
    return out_buffer.reshape(orig_shape)

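# Layout note: _dequant_mxfp8_kernel below indexes scales as
# (row * scale_num_cols + scale_col), i.e. e8m0_scales is expected to be the plain
# row-major (num_rows, num_cols // block_size) tensor produced alongside the e4m3
# data (e.g. the scales returned by to_mx), not a swizzled/blocked scale layout.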
@triton.autotune(
    configs=_get_mxfp8_quant_autotune_configs(),
    key=["input_num_cols", "SCALE_BLOCK_SIZE"],
)
@triton.jit
def _dequant_mxfp8_kernel(
    e4m3_data,
    e8m0_scales,
    out_buffer,
    input_num_rows,
    input_num_cols,
    scale_num_rows,
    scale_num_cols,
    out_dtype: tl.constexpr,
    SCALE_BLOCK_SIZE: tl.constexpr,
    ROW_TILE_SIZE: tl.constexpr,
    COL_TILE_SIZE: tl.constexpr,
):
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)
    SCALE_BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // SCALE_BLOCK_SIZE

    # Load block of e4m3 data
    row_offs = pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)
    col_offs = pid_col * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)
    block_offs = row_offs[:, None] * input_num_cols + col_offs[None, :]
    mask = (row_offs[:, None] < input_num_rows) & (
        col_offs[None, :] < input_num_cols
    )
    e4m3_data_block = tl.load(e4m3_data + block_offs, mask=mask)

    # Load block of e8m0 scales
    scale_col_offs = pid_col * SCALE_BLOCKS_PER_COL_TILE + tl.arange(
        0, SCALE_BLOCKS_PER_COL_TILE
    )
    scale_block_offs = row_offs[:, None] * scale_num_cols + scale_col_offs[None, :]
    scale_mask = (row_offs[:, None] < scale_num_rows) & (
        scale_col_offs[None, :] < scale_num_cols
    )
    e8m0_scale_block = tl.load(e8m0_scales + scale_block_offs, mask=scale_mask)

    # Dequantize: scale each 1x32 group of e4m3 values by its e8m0 scale
    e4m3_data_block_r = e4m3_data_block.reshape(
        ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE
    )
    e8m0_scale_block_r = e8m0_scale_block.reshape(
        ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, 1
    )
    fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r)
    data_hp = e4m3_data_block_r.to(tl.float32) * fp32_scale

    # Write to output buffer
    out_buffer_block = data_hp.to(out_dtype)
    out_buffer_block = out_buffer_block.reshape(ROW_TILE_SIZE, COL_TILE_SIZE)
    tl.store(out_buffer + block_offs, out_buffer_block, mask=mask)

@triton.jit
def _e8m0_to_fp32(scale_e8m0):
    e8m0_exponent_bias = 127
    e8m0_nan_val = 255
    s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
    s_fp = tl.exp2(s_offset.to(tl.float32))
    s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
    return s_fp.to(tl.float32)

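# _e8m0_to_fp32 above decodes a biased exponent byte e as 2 ** (e - 127), with the
# byte value 255 reserved for NaN (e.g. 126 -> 0.5, 127 -> 1.0, 129 -> 4.0).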
@triton.jit
def triton_scale_swizzle(
scale_ptr,
Expand Down Expand Up @@ -1641,6 +1742,14 @@ def triton_quantize_nvfp4(
) -> Tuple[torch.Tensor, torch.Tensor]:
raise AssertionError("needs torch version 2.8+ and triton")

def triton_mxfp8_dequant_dim0(
    e4m3_data: torch.Tensor,
    e8m0_scales: torch.Tensor,
    out_dtype: torch.dtype,
    scale_block_size: int = 32,
) -> torch.Tensor:
    raise AssertionError("needs torch version 2.8+ and triton")


mxfp8_cuda_extension_available = False
if is_sm_at_least_100():