| 
 | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 2 | +# All rights reserved.  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD 3-Clause license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py  | 
 | 7 | +import argparse  | 
 | 8 | +import itertools  | 
 | 9 | +from dataclasses import dataclass  | 
 | 10 | +from typing import List  | 
 | 11 | + | 
 | 12 | +import torch  | 
 | 13 | +from tabulate import tabulate  | 
 | 14 | +from tqdm import tqdm  | 
 | 15 | +from utils import benchmark_cuda_function_in_microseconds  | 
 | 16 | + | 
 | 17 | +from torchao.float8.config import ScalingGranularity  | 
 | 18 | +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated  | 
 | 19 | +from torchao.prototype.moe_training.utils import generate_jagged_offs  | 
 | 20 | +from torchao.prototype.mx_formats.mx_tensor import to_mx  | 
 | 21 | +from torchao.prototype.mx_formats.utils import (  | 
 | 22 | +    to_blocked_per_group_2d,  | 
 | 23 | +    to_blocked_per_group_3d,  | 
 | 24 | +)  | 
 | 25 | + | 
 | 26 | +device = torch.device("cuda")  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +@dataclass(frozen=True)  | 
 | 30 | +class ExperimentConfig:  | 
 | 31 | +    e: int  | 
 | 32 | +    m: int  | 
 | 33 | +    n: int  | 
 | 34 | +    k: int  | 
 | 35 | + | 
 | 36 | + | 
 | 37 | +@dataclass(frozen=True)  | 
 | 38 | +class ExperimentResult:  | 
 | 39 | +    bf16_us: float  | 
 | 40 | +    fp8_rowwise_us: float  | 
 | 41 | +    mxfp8_us: float  | 
 | 42 | + | 
 | 43 | + | 
 | 44 | +@dataclass(frozen=True)  | 
 | 45 | +class Experiment:  | 
 | 46 | +    config: ExperimentConfig  | 
 | 47 | +    result: ExperimentResult  | 
 | 48 | + | 
 | 49 | + | 
 | 50 | +def get_configs() -> List[ExperimentConfig]:  | 
 | 51 | +    # Llama4 shapes  | 
 | 52 | +    M = [16640]  | 
 | 53 | +    K = [5120]  | 
 | 54 | +    N = [8192]  | 
 | 55 | +    E = [16]  | 
 | 56 | +    configs = []  | 
 | 57 | +    for e, m, n, k in itertools.product(  | 
 | 58 | +        E,  | 
 | 59 | +        M,  | 
 | 60 | +        N,  | 
 | 61 | +        K,  | 
 | 62 | +    ):  | 
 | 63 | +        configs.append(  | 
 | 64 | +            ExperimentConfig(  | 
 | 65 | +                e=e,  | 
 | 66 | +                m=m,  | 
 | 67 | +                n=n,  | 
 | 68 | +                k=k,  | 
 | 69 | +            )  | 
 | 70 | +        )  | 
 | 71 | +    return configs  | 
 | 72 | + | 
 | 73 | + | 
 | 74 | +def run_experiment(  | 
 | 75 | +    config: ExperimentConfig, args: argparse.Namespace  | 
 | 76 | +) -> ExperimentResult:  | 
 | 77 | +    e, m, n, k = config.e, config.m, config.n, config.k  | 
 | 78 | + | 
 | 79 | +    # define test inputs  | 
 | 80 | +    A = torch.randn(  | 
 | 81 | +        (m, k),  | 
 | 82 | +        dtype=torch.bfloat16,  | 
 | 83 | +        device=device,  | 
 | 84 | +    )  | 
 | 85 | +    B_t = torch.randn(  | 
 | 86 | +        (e, n, k),  | 
 | 87 | +        dtype=torch.bfloat16,  | 
 | 88 | +        device=device,  | 
 | 89 | +        requires_grad=True,  | 
 | 90 | +    ).transpose(-2, -1)  | 
 | 91 | + | 
 | 92 | +    # Configure groups  | 
 | 93 | +    n_groups = e  | 
 | 94 | +    Mg = A.shape[0]  | 
 | 95 | +    alignment_size = 16  | 
 | 96 | +    offs = generate_jagged_offs(n_groups, Mg, multiple_of=alignment_size)  | 
 | 97 | + | 
 | 98 | +    # benchmark bf16 grouped mm  | 
 | 99 | +    bf16_us = benchmark_cuda_function_in_microseconds(  | 
 | 100 | +        torch._grouped_mm,  | 
 | 101 | +        A,  | 
 | 102 | +        B_t,  | 
 | 103 | +        offs,  | 
 | 104 | +        out_dtype=torch.bfloat16,  | 
 | 105 | +    )  | 
 | 106 | + | 
 | 107 | +    # bench fp8 rowwise grouped mm  | 
 | 108 | +    fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)  | 
 | 109 | + | 
 | 110 | +    # benchmark mxfp8 grouped mm  | 
 | 111 | +    mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)  | 
 | 112 | + | 
 | 113 | +    return ExperimentResult(  | 
 | 114 | +        bf16_us=round(bf16_us, 3),  | 
 | 115 | +        fp8_rowwise_us=round(fp8_rowwise_us, 3),  | 
 | 116 | +        mxfp8_us=round(mxfp8_us, 3),  | 
 | 117 | +    )  | 
 | 118 | + | 
 | 119 | + | 
 | 120 | +def print_results(experiments: List[Experiment]):  | 
 | 121 | +    headers = [  | 
 | 122 | +        "E",  | 
 | 123 | +        "M",  | 
 | 124 | +        "N",  | 
 | 125 | +        "K",  | 
 | 126 | +        "bf16_time_us",  | 
 | 127 | +        "fp8_rowwise_time_us",  | 
 | 128 | +        "mxfp8_time_us",  | 
 | 129 | +    ]  | 
 | 130 | +    rows = []  | 
 | 131 | +    for experiment in experiments:  | 
 | 132 | +        rows.append(  | 
 | 133 | +            [  | 
 | 134 | +                experiment.config.e,  | 
 | 135 | +                experiment.config.m,  | 
 | 136 | +                experiment.config.n,  | 
 | 137 | +                experiment.config.k,  | 
 | 138 | +                experiment.result.bf16_us,  | 
 | 139 | +                experiment.result.fp8_rowwise_us,  | 
 | 140 | +                experiment.result.mxfp8_us,  | 
 | 141 | +            ]  | 
 | 142 | +        )  | 
 | 143 | +    print(tabulate(rows, headers=headers))  | 
 | 144 | + | 
 | 145 | + | 
 | 146 | +# benchmark fp8 grouped mm  | 
 | 147 | +def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:  | 
 | 148 | +    # Convert A to float8, row-major for left operand of grouped GEMM.  | 
 | 149 | +    A_scales = tensor_to_scale(  | 
 | 150 | +        A,  | 
 | 151 | +        torch.float8_e4m3fn,  | 
 | 152 | +        scaling_granularity=ScalingGranularity.AXISWISE,  | 
 | 153 | +        axiswise_dim=-1,  | 
 | 154 | +        round_scales_to_power_of_2=True,  | 
 | 155 | +    )  | 
 | 156 | +    A_scaled = A.to(torch.float32) * A_scales  | 
 | 157 | +    A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)  | 
 | 158 | + | 
 | 159 | +    # Convert B_t to float8, column-major for right operand of grouped GEMM.  | 
 | 160 | +    B_t_scales = tensor_to_scale(  | 
 | 161 | +        B_t,  | 
 | 162 | +        torch.float8_e4m3fn,  | 
 | 163 | +        scaling_granularity=ScalingGranularity.AXISWISE,  | 
 | 164 | +        axiswise_dim=-2,  | 
 | 165 | +        round_scales_to_power_of_2=True,  | 
 | 166 | +    )  | 
 | 167 | +    B_t_scaled = B_t.to(torch.float32) * B_t_scales  | 
 | 168 | +    B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)  | 
 | 169 | + | 
 | 170 | +    # Bench the gemm  | 
 | 171 | +    fp8_us = benchmark_cuda_function_in_microseconds(  | 
 | 172 | +        torch._scaled_grouped_mm,  | 
 | 173 | +        A_fp8_row_major,  | 
 | 174 | +        B_t_fp8_col_major,  | 
 | 175 | +        A_scales.squeeze(1).reciprocal(),  | 
 | 176 | +        B_t_scales.squeeze(1).reciprocal(),  | 
 | 177 | +        offs,  | 
 | 178 | +        out_dtype=torch.bfloat16,  | 
 | 179 | +        use_fast_accum=True,  | 
 | 180 | +    )  | 
 | 181 | +    return fp8_us  | 
 | 182 | + | 
 | 183 | + | 
 | 184 | +def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:  | 
 | 185 | +    # A_mx shape: (M, K)  | 
 | 186 | +    # A_scale shape: (M, K//block_size)  | 
 | 187 | +    A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)  | 
 | 188 | + | 
 | 189 | +    # B_mx shape: (E, N, K)  | 
 | 190 | +    # B_scale shape: (E, N, K//block_size)  | 
 | 191 | +    B_scales, B_fp8 = to_mx(  | 
 | 192 | +        B_t.transpose(-2, -1),  | 
 | 193 | +        elem_dtype=torch.float8_e4m3fn,  | 
 | 194 | +        block_size=block_size,  | 
 | 195 | +    )  | 
 | 196 | + | 
 | 197 | +    # Convert scales for each group to blocked format.  | 
 | 198 | +    Mg, K = A_fp8.shape  | 
 | 199 | +    A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(  | 
 | 200 | +        A_scales, offs, Mg, K  | 
 | 201 | +    )  | 
 | 202 | +    B_scales_blocked = to_blocked_per_group_3d(B_scales)  | 
 | 203 | + | 
 | 204 | +    # From this, we compute `group_sizes` and `starting_row_after_padding`:  | 
 | 205 | +    # group_sizes = [32, 32, 64]  | 
 | 206 | +    # starting_row_after_padding = [0, 32, 64, 128]  | 
 | 207 | +    zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)  | 
 | 208 | +    group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)  | 
 | 209 | + | 
 | 210 | +    # Run the grouped mm  | 
 | 211 | +    mxfp8_us = benchmark_cuda_function_in_microseconds(  | 
 | 212 | +        torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,  | 
 | 213 | +        A_fp8,  | 
 | 214 | +        B_fp8,  | 
 | 215 | +        A_scales_blocked,  | 
 | 216 | +        B_scales_blocked,  | 
 | 217 | +        group_sizes,  | 
 | 218 | +        starting_row_after_padding=starting_row_after_padding,  | 
 | 219 | +    )  | 
 | 220 | +    return mxfp8_us  | 
 | 221 | + | 
 | 222 | + | 
 | 223 | +def main(args: argparse.Namespace):  | 
 | 224 | +    torch.random.manual_seed(123)  | 
 | 225 | +    configs = get_configs()  | 
 | 226 | +    results = []  | 
 | 227 | +    for config in tqdm(configs):  | 
 | 228 | +        result = run_experiment(config, args)  | 
 | 229 | +        results.append(Experiment(config=config, result=result))  | 
 | 230 | + | 
 | 231 | +    # Use Tabulate to print results  | 
 | 232 | +    print_results(results)  | 
 | 233 | + | 
 | 234 | + | 
 | 235 | +if __name__ == "__main__":  | 
 | 236 | +    arg_parser = argparse.ArgumentParser()  | 
 | 237 | +    args = arg_parser.parse_args()  | 
 | 238 | +    main(args)  | 
0 commit comments