diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index 4bf54538df..547b0a40e4 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -180,7 +180,7 @@ def get_gemm_times(
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add cutlass mx gemm here"
+        assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
@@ -233,6 +233,20 @@ def run(
     print(f"mx_recipe_name: {mx_recipe_name}")
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
+    assert mx_recipe_name in (
+        # real mxfp8_cublas recipe
+        "mxfp8_cublas",
+        # real mxfp8_cublas_rceil recipe
+        "mxfp8_cublas_rceil",
+        # modeling of what mxfp8 with 32x32 block size and without gemm
+        # operand layout restrictions would look like
+        "mxfp8_32x32_flexible_gemm_layout",
+        # modeling of what mxfp8 with 32x32 block size for the weight would look like
+        "mxfp8_32x32_weight",
+        # real mxfp4_cutlass recipe
+        "mxfp4_cutlass",
+    ), f"unsupported {mx_recipe_name=}"
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_float8_mem_sympy(
@@ -309,7 +323,11 @@ def run(
         rb_fp8_gemm_ratio = -1
 
     if do_benchmarks:
-        assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+        assert mx_recipe_name not in (
+            "mxfp4_cutlass",
+            "mxfp8_32x32_flexible_gemm_layout",
+            "mxfp8_32x32_weight",
+        ), f"do_benchmarks unsupported with {mx_recipe_name=}"
 
         # TODO(future): make the bf16 gemm times exactly match the e2e
         # benchmarks, there is a slight deviation, probably related to gemm
diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py
index f57705333a..e391a4d44b 100644
--- a/torchao/testing/training/roofline_utils.py
+++ b/torchao/testing/training/roofline_utils.py
@@ -187,13 +187,53 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
             assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 blocks everywhere, so the format makes
+        #    sense across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 blocks for the weight, so the format
+        #    makes sense across dim0 and dim1; input and grad_output still 1x32
+
+        if tensor_role in ("input", "grad_output"):
+            # TODO(future): update all of the mx rooflines to just read once
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1: x_bf16 -> x_mxfp8_dim0
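For readers following the roofline arithmetic in the new `mxfp8_32x32_weight` branch, below is a minimal standalone sketch of the byte counting it models. The helper name, the example shape, and the literal `BYTES_PER_EL_*` values (2 bytes per bf16 element, 1 byte per float8 element) are assumptions for illustration only; the real constants and the conversion from bytes to seconds live in `torchao/testing/training/roofline_utils.py` and are not reproduced here.

```python
# Hypothetical sketch (not part of torchao): mirrors the kernel_1_rw / kernel_2_rw
# arithmetic added above for the "mxfp8_32x32_weight" recipe, returning raw bytes.

BYTES_PER_EL_BF16 = 2  # assumption: 2 bytes per bf16 element
BYTES_PER_EL_FLOAT8 = 1  # assumption: 1 byte per float8 element


def mxfp8_32x32_weight_traffic_bytes(
    numel: int, tensor_role: str, fuse_with_prev: bool
) -> list[int]:
    """Modeled per-kernel read/write bytes for one tensor (hypothetical helper)."""
    if tensor_role in ("input", "grad_output"):
        # kernel 1: x_bf16 -> x_mxfp8_dim0 (read bf16 unless fused, write fp8)
        if fuse_with_prev:
            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
        else:
            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        # kernel 2: x_bf16 -> x_mxfp8_dim1 (re-read bf16, write fp8)
        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
    elif tensor_role == "weight":
        # kernel 1: x_bf16 -> x_mxfp8_dim0 (read bf16, write fp8)
        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 (read fp8, write fp8)
        kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
    else:
        raise ValueError(f"unsupported {tensor_role=}")
    return [kernel_1_rw, kernel_2_rw]


# example: a 4096x4096 weight -> kernel 1 moves 48 MiB, kernel 2 moves 32 MiB
print(mxfp8_32x32_weight_traffic_bytes(4096 * 4096, "weight", fuse_with_prev=False))
```

The key difference the recipe models is visible in the `weight` branch: because the 32x32 block format is valid across both dims, the dim1 cast can start from the already-quantized fp8 tensor instead of re-reading the bf16 original, which is where the traffic saving comes from.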