Commit 72ea489

tjtanaa and jingyu authored and committed
[FEAT] [Performance] Add triton mrope to replace the torch code path (vllm-project#22375)
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: jingyu <[email protected]>
1 parent 2a5985d commit 72ea489

File tree

3 files changed, +766 -0 lines changed

benchmarks/kernels/benchmark_mrope.py

Lines changed: 328 additions & 0 deletions
@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different TP sizes:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
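#
# To inspect the resulting CSV afterwards, a minimal sketch (assumes the
# pandas package is available; this benchmark itself does not depend on it):
#
#   import pandas as pd
#   df = pd.read_csv("mrope_benchmark_results_<timestamp>.csv")
#   print(df[["model_name", "tp_size", "num_tokens", "speedup"]])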
import csv
import os
import time
from datetime import datetime
from typing import Any

import numpy as np
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key
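
# NOTE: In M-RoPE (as used by Qwen2-VL/Qwen2.5-VL) the three position rows
# correspond to the temporal, height, and width axes of multimodal inputs.
# A hypothetical sketch of real position ids for one image frame covering a
# 4x4 patch grid (this benchmark samples random positions instead):
#
#   t = torch.zeros(16, dtype=torch.long)       # single frame: t is constant
#   h = torch.arange(4).repeat_interleave(4)    # row index of each patch
#   w = torch.arange(4).repeat(4)               # column index of each patch
#   positions = torch.stack([t, h, w])          # shape (3, 16)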


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Calculate statistics from a list of times."""
    times_array = np.array(times)
    return {
        "mean": np.mean(times_array),
        "median": np.median(times_array),
        "p99": np.percentile(times_array, 99),
        "min": np.min(times_array),
        "max": np.max(times_array),
    }


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: dict[str, Any] = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
    csv_writer=None,
):
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # the parameters to compute the q k v size based on tp_size
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create q k v input tensors
    # create rotary pos emb input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )

        mrope_helper_class.forward_cuda(
            positions,
            query.clone(),
            key.clone(),
        )

    torch.cuda.synchronize()

    # Time reference implementation
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()

        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )

        torch.cuda.synchronize()
        torch_times.append(time.time() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.time() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            rope_theta,
            is_neox_style,
            str(rope_scaling),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats
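
# NOTE: The loops above time with wall-clock time.time() bracketed by
# torch.cuda.synchronize(), which also counts some host-side overhead.
# A lower-overhead alternative (a sketch only, not what this script does)
# is CUDA event timing:
#
#   start = torch.cuda.Event(enable_timing=True)
#   end = torch.cuda.Event(enable_timing=True)
#   start.record()
#   mrope_helper_class.forward_cuda(positions, query.clone(), key.clone())
#   end.record()
#   torch.cuda.synchronize()
#   elapsed_s = start.elapsed_time(end) / 1000.0  # elapsed_time() returns ms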


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create CSV file for results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "rope_theta",
            "is_neox_style",
            "rope_scaling",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = [args.tp_size]

        if args.num_tokens is None:
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # get the model config
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                q_size = num_heads * head_dim
                kv_size = num_kv_heads * head_dim
                is_neox_style = True
                rope_theta = config.rope_theta
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        rope_theta=rope_theta,
                        is_neox_style=is_neox_style,
                        rope_scaling=config.rope_scaling,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")
