@@ -407,17 +407,15 @@ def matmul(a, b, activation=""):
 else:
     exit("❌ Triton and Torch differ")
 
-TORCH_HAS_FP8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz")
-
-if TORCH_HAS_FP8:
-    fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
+TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
+if TORCH_HAS_FP8 and is_cuda():
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
-    a = a.to(fp8_dtype)
+    a = a.to(torch.float8_e5m2)
     # pre-transpose b for efficiency.
     b = b.T
-    b = b.to(fp8_dtype)
+    b = b.to(torch.float8_e5m2)
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
     print(f"triton_output_with_fp8_inputs={triton_output}")
@@ -441,7 +439,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8):
+    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
         continue
     configs.append(
         triton.testing.Benchmark(
@@ -450,8 +448,8 @@ def matmul(a, b, activation=""):
             line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
             # Possible values for `line_arg`
             # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
-            line_vals=[ref_lib.lower(), "triton"],  # Label name for the lines
-            line_names=[ref_lib, "Triton"],  # Line styles
+            line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
+            line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
             styles=[("green", "-"), ("blue", "-")],
             ylabel="TFLOPS",  # Label name for the y-axis
             plot_name="matmul-performance-" +
@@ -465,19 +463,12 @@ def benchmark(M, N, K, provider, fp8_inputs):
     a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
     b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
     if TORCH_HAS_FP8 and fp8_inputs:
-        fp8_dtype = torch.float8_e4m3fn if is_cuda() else torch.float8_e4m3fnuz
-        a = a.to(fp8_dtype)
+        a = a.to(torch.float8_e5m2)
         b = b.T
-        b = b.to(fp8_dtype)
+        b = b.to(torch.float8_e5m2)
     quantiles = [0.5, 0.2, 0.8]
     if provider == ref_lib.lower():
-        if fp8_inputs:
-            one_device = torch.tensor(1., device=a.device, dtype=torch.float32)
-            ref_fn = lambda: torch._scaled_mm(a, b, scale_a=one_device, scale_b=one_device, out_dtype=torch.float16,
-                                              use_fast_accum=True)
-        else:
-            ref_fn = lambda: torch.matmul(a, b)
-        ms, min_ms, max_ms = triton.testing.do_bench(ref_fn, quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
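
For context, here is a minimal standalone sketch (not taken from this commit) of the gating this change relies on. It assumes a PyTorch build that exposes `torch.float8_e5m2` and a CUDA device, and it illustrates why the cuBLAS reference line is dropped for the fp8 benchmark: plain `torch.matmul` doesn't support fp8 inputs, so the reference result is computed after up-casting back to fp16.

```python
# Standalone sketch of the fp8 gating used above (assumptions: a PyTorch build
# exposing torch.float8_e5m2 and a CUDA device; behaviour may vary by version).
import torch

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
print(f"float8_e5m2 available: {TORCH_HAS_FP8}")

if TORCH_HAS_FP8 and torch.cuda.is_available():
    # Cast fp16 inputs down to fp8 e5m2, pre-transposing b as in the tutorial.
    a = torch.randn((64, 64), device="cuda", dtype=torch.float16).to(torch.float8_e5m2)
    b = torch.randn((64, 64), device="cuda", dtype=torch.float16).T.to(torch.float8_e5m2)
    try:
        torch.matmul(a, b)  # plain torch.matmul has no fp8 path, so this is expected to raise
    except RuntimeError as e:
        print(f"torch.matmul rejected fp8 inputs: {e}")
    # The tutorial therefore builds the reference result by up-casting back to fp16.
    ref = torch.matmul(a.to(torch.float16), b.to(torch.float16))
    print(ref.shape)
```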