diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index fbfead161a..6c8113e8cb 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -38,6 +38,14 @@
 )
 
 import torchao
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -80,40 +88,67 @@ def get_gemm_times(
     fast_accum: bool,
     recipe_name: Optional[str],
 ):
-    assert recipe_name in {"rowwise"}, (
-        "Only support real benchmarks for 'rowwise' recipe for now"
-    )
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    e4m3_dtype = torch.float8_e4m3fn
-    if torch.version.hip and torch.cuda.is_available() and is_MI300():
-        e4m3_dtype = torch.float8_e4m3fnuz
-    d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
-    A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
-    B = (
-        torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
-        .view(d2)
-        .t()
-        .contiguous()
-        .t()
-    )
+    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+        d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
+        A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
+            d1
+        )
+        B = (
+            torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
+            .t()
+            .contiguous()
+            .t()
+            .view(d2)
+        )
+    else:
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
+        A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
+        B = (
+            torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
+            .view(d2)
+            .t()
+            .contiguous()
+            .t()
+        )
+
     if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "mxfp4_cutlass":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "nvfp4":
+        scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+
     else:
         assert False, "unsupported"
 
     def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        if recipe_name == "mxfp4_cutlass":
+            return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)
+        if recipe_name == "nvfp4":
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
+        else:
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
@@ -259,12 +294,33 @@ def run(
 
         # get the float8 dynamic scaling gpu kernel time
         torch._dynamo.reset()
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(),
-            # for now, use TORCH. In the future might be interesting
-            # to benchmark AUTO and FBGEMM.
-            kernel_preference=KernelPreference.TORCH,
-        )
+        if recipe_name == "rowwise":
+            config = Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerRow(),
+                # for now, use TORCH. In the future might be interesting
+                # to benchmark AUTO and FBGEMM.
+                kernel_preference=KernelPreference.TORCH,
+            )
+        elif recipe_name == "mxfp8_cublas":
+            config = MXFPInferenceConfig(
+                activation_dtype=torch.float8_e4m3fn,
+                weight_dtype=torch.float8_e4m3fn,
+                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+            )
+        elif recipe_name == "mxfp4_cutlass":
+            config = MXFPInferenceConfig(
+                activation_dtype=torch.float4_e2m1fn_x2,
+                weight_dtype=torch.float4_e2m1fn_x2,
+                gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+            )
+        elif recipe_name == "nvfp4":
+            config = NVFP4InferenceConfig(
+                mm_config=NVFP4MMConfig.DYNAMIC,
+                use_dynamic_per_tensor_scale=False,
+            )
+        else:
+            assert False, "unsupported"
+
         m_fp8_dyn = copy.deepcopy(m_orig)
         quantize_(m_fp8_dyn, config)
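
Note for reviewers: below is a minimal standalone sketch of the mxfp4 gemm path that `get_gemm_times` now benchmarks. It is illustrative, not part of the diff: the sizes are arbitrary, and it assumes a torchao build whose CUTLASS mxfp4 kernel supports your GPU, plus a PyTorch recent enough to expose `torch.float4_e2m1fn_x2` and `torch.float8_e8m0fnu`.

```python
import torch

import torchao  # assumed built with the CUTLASS mxfp4 kernel

M, K, N = 256, 512, 256  # arbitrary sizes; K must be divisible by 32
device = torch.device("cuda")

# fp4 packs two e2m1 values per byte, so the K dimension holds K // 2 bytes
A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
    torch.float4_e2m1fn_x2
)
# the weight operand is laid out column-major, as in the benchmark above
B = (
    torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
    .t()
    .contiguous()
    .t()
    .view(torch.float4_e2m1fn_x2)
)
# mx recipes carry one e8m0 scale per 32-element block along K
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)

out = torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)  # bf16 output
print(out.shape, out.dtype)  # torch.Size([256, 256]) torch.bfloat16
```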
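
And a sketch of how the new `run()` branches apply a config end to end; again illustrative only (the toy model, sizes, and bf16/CUDA setup are my assumptions, and `MXFPInferenceConfig` is a prototype API that may change):

```python
import copy

import torch
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization.quant_api import quantize_

# toy stand-in for the m_orig model used by the benchmark
m_orig = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
)

config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# quantize_ mutates the module in place, so work on a copy
m_mxfp8 = copy.deepcopy(m_orig)
quantize_(m_mxfp8, config)

with torch.no_grad():
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
    y = m_mxfp8(x)  # runs through the mxfp8 cuBLAS gemm
```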