60 changes: 30 additions & 30 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -56,14 +56,27 @@ def __init__(

     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)

         if (
             hasattr(self, "external_shapes") and self.external_shapes
@@ -90,62 +103,49 @@ def args(m, n, k):
                 yield args(m, n, k)

     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)

-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16

+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )

     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)

         return lambda: compiled(a, b)

     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)

     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)

     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
Expand All @@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()

-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c

Expand All @@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
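
A note on how the new four-element input tuple is consumed downstream: the snippet below is a minimal standalone sketch of the tensor-wise scaling path, not part of this diff. It assumes an fp8-capable CUDA GPU (e.g. H100) and a PyTorch build whose torch._scaled_mm takes the scales positionally, as the calls in this file do; the "cuda" device string and the shapes are illustrative.

import torch

device = "cuda"  # assumed: fp8-capable hardware (e.g. H100)
m, k, n = 256, 512, 128  # arbitrary illustrative shapes

# Build fp8 operands the same way args() does; b is round-tripped through
# .T.contiguous().T so that mat2 is column-major, which _scaled_mm requires.
a = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(k, n, device=device).to(torch.float8_e4m3fn).T.contiguous().T

# Tensor-wise scales; the scaling_rowwise branch would instead pass (M, 1) and
# (1, N) float32 tensors and select bfloat16 as the output dtype.
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(
    a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)
out = out[0] if isinstance(out, tuple) else out  # some builds return (out, amax)
print(out.shape)  # torch.Size([256, 128])

Note that the Triton variants accept scale_a and scale_b only to match the shared input signature; their kernels do not use them.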