This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 7937009

ProExpertProg and mgoin authored
[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce (vllm-project#7233)
Co-authored-by: Michael Goin <[email protected]>
1 parent 9984605 commit 7937009

8 files changed (+237, -116 lines)

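Every kernel change below follows the same pattern: drop the in-tree reduction_utils.cuh helpers and perform the block-wide reduction with CUB (hipCUB on ROCm), declaring a per-block TempStorage and passing the runtime blockDim.x as the number of valid inputs. The following is a minimal standalone sketch of that pattern, not code from this commit; the kernel name, sizes, and host driver are illustrative assumptions.

// Standalone sketch (not vllm code): block-wide sum reduction with
// cub::BlockReduce, the pattern this commit adopts in place of the
// hand-rolled blockReduceSum helper.
#include <cstdio>
#include <vector>

#include <cub/cub.cuh>

template <int kMaxBlockSize>
__global__ void sum_squares_kernel(const float* __restrict__ in,
                                   float* __restrict__ out, int n) {
  // Each thread accumulates a strided partial sum of squares, as the
  // RMSNorm variance loop does.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    partial += in[i] * in[i];
  }

  // BlockReduce is templated on the maximum block size; the runtime
  // blockDim.x is passed to Reduce() as the number of valid inputs, so the
  // same instantiation serves smaller launches. The result is only valid on
  // thread 0.
  using BlockReduce = cub::BlockReduce<float, kMaxBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float total = BlockReduce(temp_storage).Reduce(partial, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    *out = total;
  }
}

int main() {
  const int n = 1 << 16;
  std::vector<float> h_in(n, 1.0f);  // sum of squares should equal n

  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Launch with 256 threads even though the reduction is sized for 1024.
  sum_squares_kernel<1024><<<1, 256>>>(d_in, d_out, n);

  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum of squares: %.1f (expected %d)\n", h_out, n);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}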

.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.409
+    value: 0.419
   - name: "exact_match,flexible-extract"
-    value: 0.406
+    value: 0.416
 limit: 1000
 num_fewshot: 5
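The gsm8k reference accuracies in this QQQ lm-eval config are bumped by 0.01; presumably the CUB-based reduction changes the floating-point accumulation order just enough to shift this quantized model's scores slightly.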
New file — layernorm kernel benchmark

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
import random
import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         add_residual: bool,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         add_residual=args.add_residual,
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
New file — quantization (fp8/int8) kernel benchmark

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype",
                        type=str,
                        choices=["fp8", "int8"],
                        default="int8")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         static_scale=args.static_scale,
         quant_dtype=to_torch_dtype(args.quant_dtype),
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
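Both new scripts use the same timing harness: seed, build the inputs on the default CUDA device, warm up for --num-warmup-iters calls, then time --num-iters calls between torch.cuda.synchronize() barriers and report the mean latency in microseconds. Passing --profile instead runs a single iteration bracketed by cudaProfilerStart/Stop so the kernel can be captured under an external profiler.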

csrc/layernorm_kernels.cu

Lines changed: 19 additions & 14 deletions
@@ -3,13 +3,16 @@
 #include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
-#include "reduction_utils.cuh"
 #ifndef USE_ROCM
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
 #else
 #include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
 
 using __nv_bfloat16 = __hip_bfloat16;
 using __nv_bfloat162 = __hip_bfloat162;

@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
     const float x = (float)input[blockIdx.x * hidden_size + idx];
     variance += x * x;
   }
-  variance = blockReduceSum<float>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }

@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
     variance += temp.sum_squares();
     residual_v[id] = temp;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }

@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
     variance += x * x;
     residual[blockIdx.x * hidden_size + idx] = z;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
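Design note: the removed helpers were templated on a block size that had to match the launch configuration, which is why the old code dispatched between blockReduceSum<float, 1024> and blockReduceSum<float, 256> and carried a comment telling readers to keep that branch in sync with max_block_size in fused_add_rms_norm. cub::BlockReduce is instead instantiated once for the maximum block size (1024) and takes the runtime blockDim.x as the count of valid inputs, so a single code path now covers both launch configurations.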

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 12 additions & 2 deletions
@@ -3,7 +3,14 @@
 #include <cmath>
 
 #include "../../dispatch_utils.h"
-#include "../../reduction_utils.cuh"
+
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM

@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
     absmax_val = val > absmax_val ? val : absmax_val;
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float block_absmax_val;
   if (tid == 0) {
     block_absmax_val = block_absmax_val_maybe;

csrc/quantization/fp8/common.cu

Lines changed: 11 additions & 2 deletions
@@ -7,7 +7,13 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
-#include "../../reduction_utils.cuh"
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
 
 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;

@@ -215,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     }
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {
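In both quantization kernels the reduction is followed by a shared-memory broadcast, because BlockReduce returns a valid result only on thread 0 while every thread needs the block's absolute maximum to compute its scale. The sketch below illustrates that max-then-broadcast shape in isolation; it is not vllm code, the names are illustrative, and a simple normalization stands in for the real quantization math.

// Standalone sketch (not vllm code): block-wide absolute maximum with
// cub::BlockReduce, then a shared-memory broadcast so every thread can use
// the result, mirroring the absmax pattern in the int8/fp8 quant kernels.
#include <cstdio>
#include <vector>

#include <cub/cub.cuh>

template <int kMaxBlockSize>
__global__ void normalize_by_absmax_kernel(const float* __restrict__ in,
                                           float* __restrict__ out, int n) {
  // Per-thread strided absolute maximum over the row.
  float absmax_val = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float val = fabsf(in[i]);
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  using BlockReduce = cub::BlockReduce<float, kMaxBlockSize>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  // Reduce() returns the block maximum only on thread 0 ...
  float block_absmax_maybe =
      BlockReduce(reduce_storage).Reduce(absmax_val, cub::Max{}, blockDim.x);

  // ... so thread 0 publishes it and every thread reads it back after the barrier.
  __shared__ float block_absmax;
  if (threadIdx.x == 0) {
    block_absmax = block_absmax_maybe;
  }
  __syncthreads();

  // Use the broadcast value; here we simply normalize the row to [-1, 1].
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[i] = in[i] / block_absmax;
  }
}

int main() {
  const int n = 4096;
  std::vector<float> h_in(n);
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i - n / 2);  // absmax = n / 2

  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  normalize_by_absmax_kernel<1024><<<1, 256>>>(d_in, d_out, n);

  std::vector<float> h_out(n);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("first element: %.4f (expected %.4f)\n", h_out[0], h_in[0] / (n / 2.0f));

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}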
