# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This script benchmarks the mrope kernel (mainly for Qwen2-VL and Qwen2.5-VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with the current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
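#
# A results file can be summarized with pandas, for example (illustrative
# snippet; assumes pandas is installed, and the timestamped filename below is
# a placeholder for a real output file):
#
#   import pandas as pd
#   df = pd.read_csv("mrope_benchmark_results_20250101_000000.csv")
#   print(df.groupby("model_name")["speedup"].describe())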
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# Single model with multiple TP sizes:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2.5-VL-7B-Instruct --tp-size 1 2 4 8 \
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
import csv
import os
import time
from datetime import datetime
from typing import Any, Optional

import numpy as np
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser

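# NOTE: this benchmark effectively requires CUDA; forward_cuda and the
# torch.cuda.synchronize() calls below assume a GPU is available.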
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create multimodal positions of shape (3, num_tokens): one row each for
    # the temporal, height, and width position indices
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Calculate statistics from a list of times."""
    times_array = np.array(times)
    return {
        "mean": np.mean(times_array),
        "median": np.median(times_array),
        "p99": np.percentile(times_array, 99),
        "min": np.min(times_array),
        "max": np.max(times_array),
    }


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
    csv_writer=None,
):
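    """Benchmark the native (PyTorch) and CUDA/Triton mrope paths for one
    configuration, print summary statistics, and optionally write a CSV row.
    """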
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # get_rope dispatches on the rope_scaling dict; for these models (whose
    # rope_scaling carries an "mrope_section" entry) it returns the multimodal
    # rotary embedding helper
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # Create the positions, query, and key input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
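    # The kernels may rotate query/key in place, so each call gets fresh
    # clones to keep the shared inputs unchanged across iterations.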
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )

        mrope_helper_class.forward_cuda(
            positions,
            query.clone(),
            key.clone(),
        )

    torch.cuda.synchronize()

    # Time reference implementation
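    # Synchronize before starting and before stopping each timer so the
    # wall-clock measurement covers the full kernel execution rather than
    # just the asynchronous launch.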
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )

        torch.cuda.synchronize()
        torch_times.append(time.perf_counter() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.perf_counter() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            rope_theta,
            is_neox_style,
            str(rope_scaling),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, nargs="+", default=[1])
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create a timestamped CSV file so repeated runs don't overwrite results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "rope_theta",
            "is_neox_style",
            "rope_scaling",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = args.tp_size

        if args.num_tokens is None:
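            # Default sweep: powers of two from 1 to 131072 tokens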
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # Derive the per-rank head counts from the model config and TP size
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
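                # KV heads are replicated across TP ranks when tp_size exceeds
                # their count, hence the max(1, ...) floor below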
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                is_neox_style = True
                rope_theta = config.rope_theta
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        rope_theta=rope_theta,
                        is_neox_style=is_neox_style,
                        rope_scaling=config.rope_scaling,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")