# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
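
"""Micro-benchmark comparing the CUDA and Triton implementations of
per-token-group quantization (fp8_utils.per_token_group_quant_fp8 and
int8_utils.per_token_group_quant_int8), reporting per-iteration latency
and the Triton/CUDA speed-up for each configuration.

Example invocation (file name illustrative; flags are defined in parse_args):

    python benchmark_per_token_group_quant.py --dtype both --bench-iters 100
"""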

import argparse
import math
from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.platforms import current_platform


@contextmanager
def _triton_mode():
    """Temporarily force the Triton fallback path.

    Patching current_platform.is_cuda to return False makes the quantization
    helpers take their Triton code path instead of the CUDA kernel.
    """
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        yield


def _time_cuda(
    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # Warm up to exclude one-time costs (e.g. kernel compilation) from timing.
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    # CUDA events measure elapsed GPU time across the whole benchmark loop.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def _run_single(
    shape: tuple[int, int],
    group_size: int,
    dtype: str,
    *,
    column_major: bool = False,
    scale_ue8m0: bool = False,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    num_tokens, hidden_dim = shape

    device = torch.device("cuda")
    torch.manual_seed(42)
    # Scale up the inputs so the quantized values span a wider dynamic range.
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    if dtype == "fp8":

        def cuda_impl():
            return fp8_utils.per_token_group_quant_fp8(
                x,
                group_size,
                column_major_scales=column_major,
                use_ue8m0=scale_ue8m0,
            )

        def triton_impl():
            with _triton_mode():
                return fp8_utils.per_token_group_quant_fp8(
                    x,
                    group_size,
                    column_major_scales=column_major,
                    use_ue8m0=scale_ue8m0,
                )
    elif dtype == "int8":

        def cuda_impl():
            return int8_utils.per_token_group_quant_int8(x, group_size)

        def triton_impl():
            with _triton_mode():
                return int8_utils.per_token_group_quant_int8(x, group_size)
    else:
        raise ValueError("dtype must be 'fp8' or 'int8'")

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)

    speedup = triton_ms / cuda_ms if cuda_ms else math.inf

    # Format booleans via !s so "True"/"False" (not 1/0) fill the 5-char columns.
    cfg_desc = (
        f"shape={shape} gs={group_size:<3} col_major={column_major!s:<5} "
        f"ue8m0={scale_ue8m0!s:<5} dtype={dtype}"
    )
    print(
        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
        f"speed-up ×{speedup:5.2f}"
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-iters", type=int, default=10)
    parser.add_argument("--bench-iters", type=int, default=100)
    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
    return parser.parse_args()


if __name__ == "__main__":
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()
    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters

    shapes = [(32, 128), (64, 256), (16, 512)]
    group_sizes = [64, 128]

    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]

    header = (
        "Configuration".ljust(55)
        + " | "
        + "CUDA (ms)".center(12)
        + " | "
        + "Triton (ms)".center(13)
        + " | "
        + "Speed-up"
    )
    print(header)
    print("-" * len(header))

    for dtype in dtypes:
        for shape in shapes:
            for gs in group_sizes:
                if dtype == "fp8":
                    for col_major in (False, True):
                        for ue8m0 in (False, True):
                            _run_single(
                                shape,
                                gs,
                                dtype,
                                column_major=col_major,
                                scale_ue8m0=ue8m0,
                                warmup_iters=warmup_iters,
                                bench_iters=bench_iters,
                            )
                else:  # INT8 has no col-major / ue8m0 switches
                    _run_single(
                        shape,
                        gs,
                        dtype,
                        warmup_iters=warmup_iters,
                        bench_iters=bench_iters,
                    )