Skip to content

Commit 43d761e

Browse files
authored
[benchmarks][vllm] Paged attention benchmark (#5348)
Closes #5257. I also started reporting gbps to the database, because many benchmarks are memory-bound.
1 parent f3a0aec commit 43d761e

File tree

4 files changed

+1321
-22
lines changed

4 files changed

+1321
-22
lines changed

.github/workflows/third-party-benchmarks.yml

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@ jobs:
9090
./scripts/test-triton.sh --install-sglang --skip-pip-install --skip-pytorch-install
9191
cd benchmarks/third_party/sglang
9292
python scaled_mm_benchmark.py --reports $REPORTS
93-
python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" --bgroup sglang
93+
python ../vllm/transform_results.py \
94+
$REPORTS/scaled_mm_benchmark.csv \
95+
$REPORTS/scaled-mm-int8-report.csv \
96+
--tag $TAG \
97+
--bgroup sglang \
98+
--benchmark scaled-mm-int8 \
99+
--param_cols="M,N,K"
94100
95101
- name: Run sglang benchmark with fp8
96102
if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }}
@@ -99,29 +105,68 @@ jobs:
99105
100106
cd benchmarks/third_party/sglang
101107
FP8="1" python scaled_mm_benchmark.py --reports $REPORTS
102-
python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 --param_cols="M,N,K" --bgroup sglang
108+
python ../vllm/transform_results.py \
109+
$REPORTS/scaled_mm_benchmark.csv \
110+
$REPORTS/scaled-mm-fp8-report.csv \
111+
--tag $TAG \
112+
--bgroup sglang \
113+
--benchmark scaled-mm-fp8 \
114+
--param_cols="M,N,K"
103115
104-
- name: Run vllm benchmarks bf16
116+
- name: Install vllm
117+
id: install-vllm
105118
if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
106119
run: |
107120
source ./scripts/capture-hw-details.sh
108-
109121
./scripts/test-triton.sh --install-vllm --skip-pip-install --skip-pytorch-install
122+
123+
- name: Run vllm unified attention bf16
124+
if: ${{ steps.install-vllm.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
125+
run: |
126+
source ./scripts/capture-hw-details.sh
127+
128+
cd benchmarks/third_party/vllm
129+
python unified_attention_benchmark.py --reports $REPORTS
130+
python transform_results.py \
131+
$REPORTS/unified-attention-performance.csv \
132+
$REPORTS/unified-attention-report.csv \
133+
--tag $TAG \
134+
--bgroup "vllm" \
135+
--benchmark "unified-attn-bf16" \
136+
--param_cols "q_heads,k_heads,head_size,dtype,qdtype,seq_lens,sliding_window,soft_cap,num_blocks,block_size"
137+
138+
- name: Run vllm batched moe bf16
139+
if: ${{ steps.install-vllm.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
140+
run: |
141+
source ./scripts/capture-hw-details.sh
142+
110143
cp -r vllm/tests benchmarks/third_party/vllm/tests
111144
112145
cd benchmarks/third_party/vllm
113146
python batched_moe_benchmark.py --reports $REPORTS
114-
python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm
147+
python transform_results.py \
148+
$REPORTS/moe-gemm-performance.csv \
149+
$REPORTS/moe-gemm-report.csv \
150+
--tag $TAG \
151+
--bgroup vllm \
152+
--benchmark moe-bf16-benchmark \
153+
--param_cols="num_experts,max_tokens_per_expert,K,N"
115154
116155
117-
- name: Run vllm benchmarks fp8
118-
if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
156+
- name: Run vllm batched moe fp8
157+
if: ${{ steps.install-vllm.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }}
119158
run: |
120159
source ./scripts/capture-hw-details.sh
121160
122161
cd benchmarks/third_party/vllm
123162
FP8="1" python batched_moe_benchmark.py --reports $REPORTS
124-
python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm
163+
python transform_results.py \
164+
$REPORTS/moe-gemm-performance.csv \
165+
$REPORTS/moe-gemm-fp8-report.csv \
166+
--tag $TAG \
167+
--bgroup vllm \
168+
--benchmark moe-fp8-benchmark \
169+
--param_cols="num_experts,max_tokens_per_expert,K,N"
125170
126171
127172
- name: Run Liger-Kernel benchmarks
@@ -136,7 +181,10 @@ jobs:
136181
bash benchmarks/third_party/liger/run_benchmarks.sh || RET_CODE=$?
137182
138183
cp Liger-Kernel/benchmark/data/all_benchmark_data.csv $REPORTS/liger-raw.csv
139-
python benchmarks/third_party/liger/transform.py $REPORTS/liger-raw.csv $REPORTS/liger-report.csv --tag $TAG
184+
python benchmarks/third_party/liger/transform.py \
185+
$REPORTS/liger-raw.csv \
186+
$REPORTS/liger-report.csv \
187+
--tag $TAG
140188
141189
# Return the captured return code at the end
142190
exit "$RET_CODE"

benchmarks/third_party/vllm/batched_moe_benchmark.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,11 +622,20 @@ def triton_fn():
622622
# Calculate performance metrics
623623
# Memory bandwidth: A (E*M*K*2) + B (E*K*N*2) + C (E*M*N*4) bytes
624624
# Compute: E * M * N * K * 2 FLOPs (multiply-add)
625+
num_activated_experts = num_expert_tokens.ne(0).sum().item()
626+
num_tokens = num_expert_tokens.sum().item()
625627

626628
def gbps(ms):
627629
n_bytes = 1 if fp8 else 2
628-
total_bytes = num_experts * (max_tokens_per_expert * K * n_bytes + K * N * n_bytes +
629-
max_tokens_per_expert * N * 2)
630+
# In practice due to the uniform distribution of lengths, on average half of the tokens are used,
631+
# let's take that into account
632+
total_bytes = (
633+
# B matrix, we only have to load activated experts
634+
num_activated_experts * (K * N * n_bytes) +
635+
# A matrix - activations, we only load part of tokens
636+
num_tokens * K * n_bytes +
637+
# C matrix - outputs, we only load/store part of tokens
638+
num_tokens * N * 2)
630639
return total_bytes * (1e-9) / (ms * 1e-3)
631640

632641
def tflops(ms):

benchmarks/third_party/vllm/transform_results.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,38 @@ def parse_csv(csv_file_path, tag, bench_group, benchmark, param_cols):
3131
run_uuid = uuid.uuid4().hex
3232
current_datetime = datetime.now().isoformat()
3333

34-
# Create params for all rows vectorized
35-
df['params'] = df.apply(lambda row: json.dumps({p: int(row[p]) for p in param_cols}), axis=1)
34+
def serialize_params(row):
35+
param2val = {}
36+
for p in param_cols:
37+
try:
38+
param2val[p] = int(row[p])
39+
except ValueError:
40+
param2val[p] = str(row[p])
41+
return json.dumps(param2val)
3642

37-
# Define compiler columns
38-
compilers = [('triton', 'triton-TFlops'), ('pytorch', 'pytorch-TFlops'), ('triton-td', 'triton-td-TFlops')]
43+
df['params'] = df.apply(serialize_params, axis=1)
44+
45+
compilers = ['pytorch', 'triton', 'triton-td']
3946

40-
# Create list of dataframes for each compiler
4147
dfs = []
42-
for compiler_name, tflops_col in compilers:
43-
if tflops_col in df.columns:
48+
for compiler_name in compilers:
49+
for value_name in ['TFlops', 'GB/s']:
50+
col = f'{compiler_name}-{value_name}'
51+
if col not in df.columns:
52+
continue
4453
# Filter out NaN values
45-
valid_rows = df[df[tflops_col].notna()].copy()
54+
valid_rows = df[df[col].notna()].copy()
4655
if len(valid_rows) > 0:
4756
valid_rows['run_uuid'] = run_uuid
4857
valid_rows['ts'] = current_datetime
4958
valid_rows['benchmark_group'] = bench_group
5059
valid_rows['benchmark'] = benchmark
5160
valid_rows['compiler'] = compiler_name
52-
valid_rows['value_name'] = 'tflops'
53-
valid_rows['value'] = valid_rows[tflops_col].astype(float)
61+
# GB/s -> gbps
62+
valid_rows['value_name'] = value_name.lower().replace('/', 'p')
63+
valid_rows['value'] = valid_rows[col].astype(float)
5464
valid_rows['tag'] = tag
5565

56-
# Select only needed columns
5766
result_df = valid_rows[[
5867
'run_uuid', 'ts', 'benchmark_group', 'benchmark', 'compiler', 'value_name', 'value', 'params', 'tag'
5968
]]

0 commit comments

Comments
 (0)