Commit 7ce00e5

support qk norm (#3145)
1 parent 4a10e29 commit 7ce00e5

17 files changed: +764 −174 lines changed

custom_ops/gpu_ops/append_attention.cu

Lines changed: 32 additions & 5 deletions
@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const bool rope_3d,
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         main_stream,
         &qkv_out,
         const_cast<paddle::Tensor*>(&key_cache),
-        const_cast<paddle::Tensor*>(&value_cache));
+        const_cast<paddle::Tensor*>(&value_cache),
+        q_norm_weight,
+        k_norm_weight,
+        rms_norm_eps);
   };
 
   if (qkv_out_scales) {
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           exec_stream,
           &qkv_out,
           const_cast<paddle::Tensor*>(&key_cache),
-          const_cast<paddle::Tensor*>(&value_cache));
+          const_cast<paddle::Tensor*>(&value_cache),
+          q_norm_weight,
+          k_norm_weight,
+          rms_norm_eps);
     } else {
       DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
           meta_data,
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           exec_stream,
           &qkv_out,
           const_cast<paddle::Tensor*>(&key_cache),
-          const_cast<paddle::Tensor*>(&value_cache));
+          const_cast<paddle::Tensor*>(&value_cache),
+          q_norm_weight,
+          k_norm_weight,
+          rms_norm_eps);
     }
   }
 
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
       out_linear_shifts,
       out_linear_smooths,
       kv_signal_data,
+      q_norm_weight,
+      k_norm_weight,
+      rms_norm_eps,
       cache_quant_type_str,
       use_neox_rotary_style,
       rope_3d,
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
     const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
     const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
     const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
+    const float rms_norm_eps,
     const std::string& compute_dtype,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("cache_v_zp"),
              paddle::Optional("out_linear_shifts"),
              paddle::Optional("out_linear_smooths"),
-             paddle::Optional("kv_signal_data")})
+             paddle::Optional("kv_signal_data"),
+             paddle::Optional("q_norm_weight"),
+             paddle::Optional("k_norm_weight")})
     .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
             "encoder_max_partition_size: int",
             "speculate_max_draft_token_num: int",
             "causal: bool",
-            "speculate_decoder: bool"})
+            "speculate_decoder: bool",
+            "rms_norm_eps: float"})
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
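
Note: the three new arguments travel from the op boundary down to the decoder cache-write kernels as paddle::optional inputs plus a float attribute, so QK norm stays opt-in. A minimal sketch of how such optional gamma tensors are typically unwrapped before a raw-pointer kernel launch (the helper name OptionalWeightPtr is illustrative, not part of this commit):

template <typename T>
const T* OptionalWeightPtr(const paddle::optional<paddle::Tensor>& w) {
  // nullptr signals "QK norm disabled" to the downstream kernel.
  return w ? w.get().data<T>() : nullptr;
}

Passing nullptr leaves the pre-existing non-QK-norm paths untouched, which is why q_norm_weight and k_norm_weight can be declared paddle::Optional in the registration above.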

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh

Lines changed: 136 additions & 0 deletions
@@ -18,6 +18,142 @@
 #include "mma_tensor_op.cuh"
 #include "utils.cuh"
 
+template <typename T, int VecSize = 1>
+__global__ void append_decode_cache_T_rope_qk_norm_kernel(
+    const T* __restrict__ quant_qkv,  // [bsz, num_heads + 2 * kv_num_heads,
+                                      // head_size]
+    T* __restrict__ key_cache,    // [num_blocks, kv_num_heads, block_size,
+                                  // head_size // 2]
+    T* __restrict__ value_cache,  // [num_blocks, kv_num_heads, block_size,
+                                  // head_size // 2]
+    T* __restrict__ qkv_out,
+    const int* __restrict__ block_tables,        // [bsz, max_blocks_per_seq]
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
+    const int* __restrict__ seq_lens,            // [bsz]
+    const int* __restrict__ seq_lens_encoder,    // [bsz]
+    const float* __restrict__ cos_emb,
+    const float* __restrict__ sin_emb,
+    const int max_seq_len,
+    const int max_blocks_per_seq,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const uint32_t elem_cnt,
+    const int kv_num_heads,
+    const bool rope_3d,
+    const T* q_norm_weight,
+    const T* k_norm_weight,
+    const float rms_norm_eps) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using LoadBiasT = AlignedVector<T, VecSize>;
+  using LoadKVT = AlignedVector<T, VecSize>;
+  constexpr int HalfVecSize = VecSize / 2;
+  using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  LoadT src_vec;
+  LoadBiasT out_vec;
+  LoadKVT cache_vec;
+  LoadEmbT cos_emb_vec;
+  LoadEmbT sin_emb_vec;
+
+  int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t all_head_dim = elem_cnt / head_size;
+
+  const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
+  // const int64_t offset = 2 * hidden_size;
+  const int half_head_size = head_size / 2;
+  for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim;
+       gloabl_hi += all_warp_num) {
+    int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
+    const int ori_bi = linear_index / hidden_size;
+    const int bias = linear_index % hidden_size;
+    const int hi = bias / head_size;  // q + k + v
+    const int h_bias = bias % head_size;
+    const int start_token_idx = cu_seqlens_q[ori_bi];
+    if (seq_lens_encoder[ori_bi] > 0) return;
+    const int write_seq_id = seq_lens[ori_bi];
+    if (write_seq_id == 0) continue;
+
+    const int* block_table_now = nullptr;
+
+    block_table_now = block_tables + ori_bi * max_blocks_per_seq;
+    const int block_idx = block_table_now[write_seq_id / block_size];
+    const int block_offset = write_seq_id % block_size;
+    const uint32_t ori_idx =
+        start_token_idx * hidden_size + hi * head_size + h_bias;
+
+    const int bias_idx = hi * head_size + h_bias;
+    Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
+    if (hi < num_heads + kv_num_heads) {
+      // q k rope
+      const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
+      uint32_t new_emb_idx =
+          rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
+    }
+    float thread_m2 = 0.0f;
+    float warp_m2 = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < HalfVecSize; i++) {
+      // dequant + add_bias + rope
+      float input_left = static_cast<float>(src_vec[2 * i]);
+      float input_right = static_cast<float>(src_vec[2 * i + 1]);
+
+      if (hi < num_heads + kv_num_heads) {
+        const float cos_tmp = cos_emb_vec[i];
+        const float sin_tmp = sin_emb_vec[i];
+        float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+        float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
+        thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
+        out_vec[2 * i] = static_cast<T>(tmp1);
+        out_vec[2 * i + 1] = static_cast<T>(tmp2);
+      } else {
+        out_vec[2 * i] = src_vec[2 * i];
+        out_vec[2 * i + 1] = src_vec[2 * i + 1];
+      }
+    }
+    if (hi < (num_heads + kv_num_heads)) {  // q k
+      WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
+      float row_variance = max(warp_m2 / head_size, 0.0f);
+      float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
+      LoadT q_norm_vec, k_norm_vec;
+      if (hi < num_heads) {  // q
+        Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+#pragma unroll
+        for (int i = 0; i < VecSize; i++) {
+          out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) *
+                                      row_inv_var *
+                                      static_cast<float>(q_norm_vec[i]));
+        }
+      } else {  // k
+        Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+        for (int i = 0; i < VecSize; i++) {
+          out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) *
+                                      row_inv_var *
+                                      static_cast<float>(k_norm_vec[i]));
+        }
+      }
+    }
+    if (hi < num_heads) {
+      // write q
+      Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
+    } else {
+      // quant + write k/v
+      const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
+      const uint32_t tgt_idx =
+          block_idx * kv_num_heads * block_size * head_size +
+          kv_head_idx * block_size * head_size + block_offset * head_size +
+          h_bias;
+      if (hi < num_heads + kv_num_heads) {
+        Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
+      } else {
+        Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
+      }
+    }
+  }
+}
+
 template <typename T, int VecSize = 1>
 __global__ void append_decode_cache_T_rope_kernel(
     const T* __restrict__ quant_qkv,  // [bsz, num_heads + 2 * kv_num_heads,
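
The new kernel fuses QK norm into the decode-time RoPE + cache-write pass: each thread accumulates the squared post-RoPE values for its slice of a head (thread_m2), WelfordWarpAllReduce combines those partial sums across the warp covering that head, and the rotated values are scaled by rsqrt(mean_square + rms_norm_eps) times the learned per-channel weight before being written to qkv_out or the key cache. An unfused, host-side reference for that per-head math (a sketch for clarity; the function name and use of std::vector are mine, not from the repo):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference QK RMSNorm over one head of head_size elements:
//   y[i] = x[i] * rsqrt(mean(x * x) + eps) * w[i],
// matching what the fused kernel applies after rotary embedding.
std::vector<float> RmsNormHead(const std::vector<float>& x,
                               const std::vector<float>& w,
                               float eps) {
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.0f / std::sqrt(sum_sq / x.size() + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * inv_rms * w[i];
  }
  return y;
}

Note that the statistic is the raw second moment (warp_m2), not a mean-subtracted variance, so this is RMSNorm rather than LayerNorm, and value heads (hi >= num_heads + kv_num_heads) skip both RoPE and the norm.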
