Commit ca77a33

Revert "[TUTORIAL][03] use float8_e4m3fn(uz) instead of e5m2 and add PyTorch comparison (#6850)"

This reverts commit 7f3e938.

1 parent f77a8dd
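
The revert swaps the tutorial's FP8 dtype back from float8_e4m3fn (CUDA) / float8_e4m3fnuz (ROCm) to float8_e5m2 and drops the torch._scaled_mm reference path, so the FP8 benchmark once again runs on CUDA only and plots Triton alone. Which FP8 dtypes a given PyTorch build exposes can be checked directly; a minimal sketch (not part of the tutorial), assuming only a reasonably recent PyTorch:

import torch

# Probe the FP8 dtypes this PyTorch build exposes. The tutorial's
# TORCH_HAS_FP8 flag is exactly such a hasattr() check.
for name in ("float8_e5m2", "float8_e4m3fn", "float8_e4m3fnuz"):
    print(f"torch.{name}: {hasattr(torch, name)}")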

File tree: 1 file changed

python/tutorials/03-matrix-multiplication.py

Lines changed: 10 additions & 19 deletions
@@ -407,17 +407,15 @@ def matmul(a, b, activation=""):
 else:
     exit("❌ Triton and Torch differ")
 
-TORCH_HAS_FP8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz")
-
-if TORCH_HAS_FP8:
-    fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
+TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
+if TORCH_HAS_FP8 and is_cuda():
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
-    a = a.to(fp8_dtype)
+    a = a.to(torch.float8_e5m2)
     # pre-transpose b for efficiency.
     b = b.T
-    b = b.to(fp8_dtype)
+    b = b.to(torch.float8_e5m2)
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
     print(f"triton_output_with_fp8_inputs={triton_output}")
@@ -441,7 +439,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8):
+    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
         continue
     configs.append(
         triton.testing.Benchmark(
@@ -450,8 +448,8 @@ def matmul(a, b, activation=""):
             line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
             # Possible values for `line_arg`
             # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
-            line_vals=[ref_lib.lower(), "triton"],  # Label name for the lines
-            line_names=[ref_lib, "Triton"],  # Line styles
+            line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
+            line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
             styles=[("green", "-"), ("blue", "-")],
             ylabel="TFLOPS",  # Label name for the y-axis
             plot_name="matmul-performance-" +
@@ -465,19 +463,12 @@ def benchmark(M, N, K, provider, fp8_inputs):
     a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
     b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
-        fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
-        a = a.to(fp8_dtype)
+        a = a.to(torch.float8_e5m2)
         b = b.T
-        b = b.to(fp8_dtype)
+        b = b.to(torch.float8_e5m2)
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
-        if fp8_inputs:
-            one_device = torch.tensor(1., device=a.device, dtype=torch.float32)
-            ref_fn = lambda: torch._scaled_mm(a, b, scale_a=one_device, scale_b=one_device, out_dtype=torch.float16,
-                                              use_fast_accum=True)
-        else:
-            ref_fn = lambda: torch.matmul(a, b)
-        ms, min_ms, max_ms = triton.testing.do_bench(ref_fn, quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
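
The branch deleted above was the PyTorch-side FP8 reference that #6850 introduced: since torch.matmul cannot take FP8 inputs, it called the private torch._scaled_mm with unit per-tensor scales instead. Reconstructed below as a standalone sketch, assuming a CUDA GPU with FP8 support (e.g. compute capability 9.0) and a recent PyTorch (roughly 2.5+, where _scaled_mm returns a single tensor; it is a private API and its signature has changed across releases):

import torch

if torch.cuda.is_available() and hasattr(torch, "float8_e4m3fn"):
    a = torch.randn(512, 512, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    # _scaled_mm wants the second operand column-major, hence the .T view
    # taken before the cast (the cast preserves the transposed strides).
    b = torch.randn(512, 512, device="cuda", dtype=torch.float16).T.to(torch.float8_e4m3fn)
    one = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    out = torch._scaled_mm(a, b, scale_a=one, scale_b=one, out_dtype=torch.float16,
                           use_fast_accum=True)
    print(out.shape)  # torch.Size([512, 512])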
