benchmarks/float8/float8_inference_roofline.py (106 changes: 81 additions, 25 deletions)
@@ -38,6 +38,14 @@
 )
 
 import torchao
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
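Side note on the packed fp4 dtype used in the next hunk: torch.float4_e2m1fn_x2 packs two 4-bit e2m1 values into each byte, which is why the fp4 buffers below allocate K // 2 bytes along the contraction dimension. A minimal sketch of that packing, assuming a recent PyTorch build that defines this dtype (shapes only, not part of the PR):

import torch

M, K = 16, 64
# two fp4 (e2m1) values live in each uint8 byte, so K fp4 elements need K // 2 bytes
raw = torch.randint(0, 255, (M, K // 2), dtype=torch.uint8)
packed = raw.view(torch.float4_e2m1fn_x2)
# the view reinterprets the dtype only; each element still holds two fp4 values
assert packed.shape == (M, K // 2)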
@@ -80,40 +88,67 @@ def get_gemm_times(
     fast_accum: bool,
     recipe_name: Optional[str],
 ):
-    assert recipe_name in {"rowwise"}, (
-        "Only support real benchmarks for 'rowwise' recipe for now"
-    )
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    e4m3_dtype = torch.float8_e4m3fn
-    if torch.version.hip and torch.cuda.is_available() and is_MI300():
-        e4m3_dtype = torch.float8_e4m3fnuz
-    d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
-    A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
-    B = (
-        torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
-        .view(d2)
-        .t()
-        .contiguous()
-        .t()
-    )
+    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+        d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
+        A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
+            d1
+        )
+        B = (
+            torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
+            .t()
+            .contiguous()
+            .t()
+            .view(d2)
+        )
+    else:
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
+        A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
+        B = (
+            torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
+            .view(d2)
+            .t()
+            .contiguous()
+            .t()
+        )
 
     if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "mxfp4_cutlass":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "nvfp4":
+        scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+
     else:
         assert False, "unsupported"
 
     def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        if recipe_name == "mxfp4_cutlass":
+            return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)
+        if recipe_name == "nvfp4":
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
+        else:
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

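For reference, the dummy scales above mirror each recipe's block-scaling layout: rowwise fp8 uses one fp32 scale per row of A (and per column of B), the MX recipes use one e8m0 scale per 32 elements along K, and NVFP4 uses one fp8 e4m3 scale per 16 elements along K. A shapes-only sketch of that arithmetic (the helper name is hypothetical, not part of the benchmark):

def scale_shapes(recipe_name: str, M: int, N: int, K: int):
    # (scale_a, scale_b) shapes fed to the gemm, mirroring the branches above
    if recipe_name == "rowwise":
        return (M, 1), (1, N)              # one fp32 scale per row / per column
    elif recipe_name in ("mxfp8_cublas", "mxfp4_cutlass"):
        return (M, K // 32), (N, K // 32)  # one e8m0 scale per 32-element block of K
    elif recipe_name == "nvfp4":
        return (M, K // 16), (N, K // 16)  # one e4m3 scale per 16-element block of K
    raise ValueError(f"unsupported recipe: {recipe_name}")

# e.g. a 4096x4096x4096 gemm under nvfp4 carries (4096, 256)-shaped scales on both sides
assert scale_shapes("nvfp4", 4096, 4096, 4096) == ((4096, 256), (4096, 256))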
@@ -259,12 +294,33 @@ def run(
     # get the float8 dynamic scaling gpu kernel time
     torch._dynamo.reset()
 
-    config = Float8DynamicActivationFloat8WeightConfig(
-        granularity=PerRow(),
-        # for now, use TORCH. In the future might be interesting
-        # to benchmark AUTO and FBGEMM.
-        kernel_preference=KernelPreference.TORCH,
-    )
+    if recipe_name == "rowwise":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            # for now, use TORCH. In the future might be interesting
+            # to benchmark AUTO and FBGEMM.
+            kernel_preference=KernelPreference.TORCH,
+        )
+    elif recipe_name == "mxfp8_cublas":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+        )
+    elif recipe_name == "mxfp4_cutlass":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float4_e2m1fn_x2,
+            weight_dtype=torch.float4_e2m1fn_x2,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+        )
+    elif recipe_name == "nvfp4":
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.DYNAMIC,
+            use_dynamic_per_tensor_scale=False,
+        )
+    else:
+        assert False, "unsupported"
 
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config)