
Commit ec4fc2c

happierpig authored
feat: add warp-level persistent qk norm (#1843)
## 📌 Description

Recent models apply QK normalization right before RoPE and core self-attention (e.g., Qwen-3, Wan). The existing RMSNorm implementation in FlashInfer falls short of optimal for this pattern:

1. It performs an extra shared-memory reduction step.
2. It does not support a non-contiguous layout on the middle dimension, e.g., q may be `[batch_size, :num_qo_heads, head_dim]`.

This PR implements a persistent version of RMSNorm in which each head is handled by a single warp.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

---

Co-authored-by: happierpig <[email protected]>
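For context, a minimal usage sketch of the new 3-D path, assuming the public `flashinfer.norm.rmsnorm` API exercised by the test added below; shapes and values are illustrative, not from the PR:

```python
import torch
import flashinfer

# One normalization job per (batch, head) pair; each job maps to a warp.
q = torch.randn(8, 16, 128, dtype=torch.float16, device="cuda")
w = torch.randn(128, dtype=torch.float16, device="cuda")

# A 3-D input dispatches to the warp-level persistent QK RMSNorm kernel.
q_normed = flashinfer.norm.rmsnorm(q, w)
```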
1 parent be130a7 · commit ec4fc2c

File tree: 4 files changed, +204 −19 lines.

csrc/norm.cu

Lines changed: 49 additions & 17 deletions

```diff
@@ -23,27 +23,59 @@ using tvm::ffi::Tensor;
 
 void rmsnorm(Tensor output, Tensor input, Tensor weight, double eps, bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   CHECK_DEVICE(input, weight);
-  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
   CHECK_DIM(1, weight);  // weight: (hidden_size)
-  TVM_FFI_ICHECK_EQ(input->shape[1], weight->shape[0]);
-  unsigned int batch_size = input->shape[0];
-  unsigned int hidden_size = input->shape[1];
-  TVM_FFI_ICHECK_EQ(output->shape[0], batch_size);
-  TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size);
-  cudaSetDevice(input->device.device_id);
-  const cudaStream_t stream = get_stream(input->device);
 
-  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
-    cudaError_t status =
-        norm::RMSNorm(static_cast<c_type*>(input->data), static_cast<c_type*>(weight->data),
-                      static_cast<c_type*>(output->data), batch_size, hidden_size,
-                      input->strides[0], output->strides[0], eps, enable_pdl, stream);
-    TVM_FFI_ICHECK(status == cudaSuccess)
-        << "RMSNorm failed with error code " << cudaGetErrorString(status);
-    return true;
-  });
+  auto input_ndim = input->ndim;
+  if (input_ndim == 2) {
+    // Normal RMSNorm: [batch_size, hidden_size]
+    // Use CTA parallelization for better parallelism
+    CHECK_DIM(2, output);
+    TVM_FFI_ICHECK_EQ(input->shape[1], weight->shape[0]);
+    unsigned int batch_size = input->shape[0];
+    unsigned int hidden_size = input->shape[1];
+    TVM_FFI_ICHECK_EQ(output->shape[0], batch_size);
+    TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size);
+    cudaSetDevice(input->device.device_id);
+    const cudaStream_t stream = get_stream(input->device);
+
+    DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
+      cudaError_t status =
+          norm::RMSNorm(static_cast<c_type*>(input->data), static_cast<c_type*>(weight->data),
+                        static_cast<c_type*>(output->data), batch_size, hidden_size,
+                        input->strides[0], output->strides[0], eps, enable_pdl, stream);
+      TVM_FFI_ICHECK(status == cudaSuccess)
+          << "RMSNorm failed with error code " << cudaGetErrorString(status);
+      return true;
+    });
+  } else if (input_ndim == 3) {
+    // QK RMSNorm: [batch_size, num_heads, head_dim]
+    // Use warp-level parallelization
+    CHECK_DIM(3, output);  // output: (batch_size, num_heads, hidden_size)
+    TVM_FFI_ICHECK_EQ(input->shape[2], weight->shape[0]);
+    unsigned int batch_size = input->shape[0];
+    unsigned int num_heads = input->shape[1];
+    unsigned int hidden_size = input->shape[2];
+    TVM_FFI_ICHECK_EQ(output->shape[0], batch_size);
+    TVM_FFI_ICHECK_EQ(output->shape[1], num_heads);
+    TVM_FFI_ICHECK_EQ(output->shape[2], hidden_size);
+
+    cudaSetDevice(input->device.device_id);
+    const cudaStream_t stream = get_stream(input->device);
+    DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
+      cudaError_t status = norm::QKRMSNorm(
+          static_cast<c_type*>(input->data), static_cast<c_type*>(weight->data),
+          static_cast<c_type*>(output->data), batch_size, num_heads, hidden_size,
+          input->strides[0], input->strides[1], output->strides[0], output->strides[1],
+          eps, enable_pdl, stream);
+      TVM_FFI_ICHECK(status == cudaSuccess)
+          << "QKRMSNorm failed with error code " << cudaGetErrorString(status);
+      return true;
+    });
+  } else {
+    TVM_FFI_ICHECK(false) << "Unsupported input dimension: " << input_ndim;
+  }
 }
 
 void fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, bool enable_pdl) {
```
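The 3-D branch above forwards both `input->strides[0]` and `input->strides[1]` (and the matching output strides) to `norm::QKRMSNorm`, which is what makes the non-contiguous middle dimension from the PR description work. A hedged sketch of that case, with illustrative tensor names:

```python
import torch
import flashinfer

# A fused QKV projection output whose q slice is non-contiguous in the
# head dimension: the batch stride is 3 * num_heads * head_dim elements.
num_heads, head_dim = 8, 128
qkv = torch.randn(4, 3 * num_heads, head_dim, dtype=torch.float16, device="cuda")
q = qkv[:, :num_heads, :]  # [batch_size, :num_qo_heads, head_dim] view
w = torch.randn(head_dim, dtype=torch.float16, device="cuda")

# No .contiguous() copy needed; the kernel consumes the strides directly.
q_normed = flashinfer.norm.rmsnorm(q, w)
```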

flashinfer/norm.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -54,7 +54,7 @@ def rmsnorm(
     Parameters
     ----------
     input: torch.Tensor
-        Input tensor, shape (batch_size, hidden_size).
+        Input tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     weight: torch.Tensor
         Weight tensor, shape (hidden_size,).
     eps: float
@@ -68,7 +68,7 @@ def rmsnorm(
     Returns
     -------
     output: torch.Tensor
-        Normalized tensor, shape (batch_size, hidden_size).
+        Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
```

include/flashinfer/norm.cuh

Lines changed: 122 additions & 0 deletions

```diff
@@ -139,6 +139,128 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void QKRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
+                                T* __restrict__ output, const uint32_t d, const uint32_t batch_size,
+                                const uint32_t num_heads, const uint32_t stride_input_n,
+                                const uint32_t stride_input_h, const uint32_t stride_output_n,
+                                const uint32_t stride_output_h, float weight_bias, float eps) {
+  const uint32_t num_blks = gridDim.x, num_warps = blockDim.y;
+  const uint32_t num_workers = num_blks * num_warps;  // unroll on warp-dim
+  const uint32_t num_jobs = batch_size * num_heads;
+
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t worker_idx = bx * num_warps + ty;
+
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_threads = warp_size;
+  const uint32_t thread_id = tx;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
+  for (uint32_t job_idx = worker_idx; job_idx < num_jobs; job_idx += num_workers) {
+    // clear buffer
+    float sum_sq = 0.f;
+
+    // map back to batch-idx and head-idx; layout [batch_size, num_heads, head_dim]
+    const uint32_t batch_idx = job_idx / num_heads;
+    const uint32_t head_idx = job_idx % num_heads;
+
+    for (uint32_t i = 0; i < rounds; i++) {
+      vec_t<T, VEC_SIZE> input_vec;
+      input_vec.fill(0.f);
+      if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+        input_vec.load(input + batch_idx * stride_input_n + head_idx * stride_input_h +
+                       i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; j++) {
+        sum_sq += float(input_vec[j]) * float(input_vec[j]);
+      }
+    }
+
+    // warp-level reduction is sufficient here;
+    // no need for __syncwarp() as shfl already synchronizes
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+
+    float rms_rcp = math::rsqrt(sum_sq / float(d) + eps);
+
+    for (uint32_t i = 0; i < rounds; i++) {
+      vec_t<T, VEC_SIZE> input_vec;
+      vec_t<T, VEC_SIZE> weight_vec;
+      vec_t<T, VEC_SIZE> output_vec;
+      input_vec.fill(0.f);
+      weight_vec.fill(0.f);
+      if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+        input_vec.load(input + batch_idx * stride_input_n + head_idx * stride_input_h +
+                       i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+        weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; j++) {
+        output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
+      }
+      if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+        output_vec.store(output + batch_idx * stride_output_n + head_idx * stride_output_h +
+                         i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      }
+    }
+  }
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename T>
+cudaError_t QKRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t num_heads,
+                      uint32_t d, uint32_t stride_input_n, uint32_t stride_input_h,
+                      uint32_t stride_output_n, uint32_t stride_output_h, float eps = 1e-5,
+                      bool enable_pdl = false, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  const uint32_t num_warps = 4;
+  const uint32_t smem_size = 0;
+
+  float weight_bias = 0.f;
+
+  cudaLaunchConfig_t config;
+  config.dynamicSmemBytes = smem_size;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = QKRMSNormKernel<VEC_SIZE, T>;
+
+    // calculate launching blocks
+    int num_blocks_per_sm = 0, num_sms = 0, dev_id = 0;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_warps * 32, smem_size));
+    FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+    FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
+    dim3 nblks(num_blocks_per_sm * num_sms);
+    dim3 nthrs(32, num_warps);
+    config.gridDim = nblks;
+    config.blockDim = nthrs;
+
+    // execute kernel
+    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, batch_size,
+                                            num_heads, stride_input_n, stride_input_h,
+                                            stride_output_n, stride_output_h, weight_bias, eps));
+  });
+  return cudaSuccess;
+}
+
 template <uint32_t VEC_SIZE, typename T>
 __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
                                       T* __restrict__ weight, const uint32_t d,
```
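Each warp owns one `(batch_idx, head_idx)` job (`job_idx = batch_idx * num_heads + head_idx`) and strides through jobs persistently across a grid sized to fill the device. The per-job math, restated as a NumPy reference under the same conventions (fp32 accumulation, `weight_bias = 0` for plain RMSNorm), is a sketch of the computation, not the kernel itself:

```python
import numpy as np

def qk_rmsnorm_reference(x, w, eps=1e-5, weight_bias=0.0):
    """x: [batch_size, num_heads, head_dim], w: [head_dim]."""
    xf = x.astype(np.float32)
    # Per-(batch, head) sum of squares; in the kernel this is a
    # shfl_xor tree reduction across the 32 lanes of one warp.
    sum_sq = (xf * xf).sum(axis=-1, keepdims=True)
    rms_rcp = 1.0 / np.sqrt(sum_sq / x.shape[-1] + eps)
    return (xf * rms_rcp * (weight_bias + w.astype(np.float32))).astype(x.dtype)
```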

tests/utils/test_norm.py

Lines changed: 31 additions & 0 deletions

```diff
@@ -93,6 +93,37 @@ def test_norm(batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguou
     torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("num_heads", [4, 7, 16])
+@pytest.mark.parametrize("head_dim", [64, 128, 256, 512])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+@pytest.mark.parametrize("enable_pdl", [True, False])
+@pytest.mark.parametrize("contiguous", [True, False])
+def test_qknorm(
+    batch_size, num_heads, head_dim, dtype, specify_out, enable_pdl, contiguous
+):
+    if contiguous:
+        x = torch.randn(batch_size, num_heads, head_dim).to(0).to(dtype)
+    else:
+        x = torch.randn(batch_size, num_heads * 2, head_dim, device="cuda").to(dtype)
+        x = x[:, :num_heads, :head_dim]
+
+    if enable_pdl and not device_support_pdl(x.device):
+        pytest.skip("PDL is only available for Hopper and later GPUs")
+
+    w = torch.randn(head_dim).to(0).to(dtype)
+
+    y_ref = llama_rms_norm(x, w)
+    if specify_out:
+        y = torch.empty_like(x)
+        flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
+    else:
+        y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
+
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
```
