|
18 | 18 | #include "mma_tensor_op.cuh"
|
19 | 19 | #include "utils.cuh"
|
20 | 20 |
|
| 21 | +template <typename T, int VecSize = 1> |
| 22 | +__global__ void append_decode_cache_T_rope_qk_norm_kernel( |
| 23 | + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, |
| 24 | + // head_size] |
| 25 | + T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, |
| 26 | + // head_size] |
| 27 | + T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, |
| 28 | + // head_size] |
| 29 | + T* __restrict__ qkv_out, |
| 30 | + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] |
| 31 | + const int* __restrict__ batch_id_per_token, // [num_tokens] |
| 32 | + const int* __restrict__ cu_seqlens_q, |
| 33 | + const int* __restrict__ seq_lens, // [bsz] |
| 34 | + const int* __restrict__ seq_lens_encoder, // [bsz] |
| 35 | + const float* __restrict__ cos_emb, |
| 36 | + const float* __restrict__ sin_emb, |
| 37 | + const int max_seq_len, |
| 38 | + const int max_blocks_per_seq, |
| 39 | + const int num_heads, |
| 40 | + const int head_size, |
| 41 | + const int block_size, |
| 42 | + const uint32_t elem_cnt, |
| 43 | + const int kv_num_heads, |
| 44 | + const bool rope_3d, |
| 45 | + const T* q_norm_weight, |
| 46 | + const T* k_norm_weight, |
| 47 | + const float rms_norm_eps) { |
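| | + // Decode-step append path with fused qk norm: apply RoPE to Q/K, RMS-normalize |
| | + // them with q_norm_weight / k_norm_weight, then write Q to qkv_out and K/V into |
| | + // the paged KV cache blocks. |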
| 48 | + using LoadT = AlignedVector<T, VecSize>; |
| 49 | + using LoadBiasT = AlignedVector<T, VecSize>; |
| 50 | + using LoadKVT = AlignedVector<T, VecSize>; |
| 51 | + constexpr int HalfVecSize = VecSize / 2; |
| 52 | + using LoadEmbT = AlignedVector<float, HalfVecSize>; |
| 53 | + LoadT src_vec; |
| 54 | + LoadBiasT out_vec; |
| 55 | + LoadKVT cache_vec; |
| 56 | + LoadEmbT cos_emb_vec; |
| 57 | + LoadEmbT sin_emb_vec; |
| 58 | + |
| 59 | + int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 60 | + int64_t all_warp_num = gridDim.x * blockDim.x; |
| 61 | + int64_t all_head_dim = elem_cnt / head_size; |
| 62 | + |
| 63 | + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; |
| 64 | + // const int64_t offset = 2 * hidden_size; |
| 65 | + const int half_head_size = head_size / 2; |
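| | + // Grid-stride loop over head_size-wide rows: each x-thread owns one row and |
| | + // threadIdx.y selects its VecSize-element slice within that row. |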
| 66 | + for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) { |
| 67 | + int64_t linear_index = global_hi * head_size + threadIdx.y * VecSize; |
| 68 | + const int ori_bi = linear_index / hidden_size; |
| 69 | + const int bias = linear_index % hidden_size; |
| 70 | + const int hi = bias / head_size; // q + k + v |
| 71 | + const int h_bias = bias % head_size; |
| 72 | + const int start_token_idx = cu_seqlens_q[ori_bi]; |
| 73 | + if (seq_lens_encoder[ori_bi] > 0) return; |
| 74 | + const int write_seq_id = seq_lens[ori_bi]; |
| 75 | + if (write_seq_id == 0) continue; |
| 76 | + |
| 77 | + const int* block_table_now = nullptr; |
| 78 | + |
| 79 | + block_table_now = block_tables + ori_bi * max_blocks_per_seq; |
| 80 | + const int block_idx = block_table_now[write_seq_id / block_size]; |
| 81 | + const int block_offset = write_seq_id % block_size; |
| 82 | + const uint32_t ori_idx = |
| 83 | + start_token_idx * hidden_size + hi * head_size + h_bias; |
| 84 | + |
| 85 | + const int bias_idx = hi * head_size + h_bias; |
| 86 | + Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec); |
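| | + // Heads [0, num_heads) are Q and [num_heads, num_heads + kv_num_heads) are K; |
| | + // both are rotated below. The remaining V heads are copied through unchanged. |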
| 87 | + if (hi < num_heads + kv_num_heads) { |
| 88 | + // q k rope |
| 89 | + const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; |
| 90 | + uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; |
| 91 | + Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec); |
| 92 | + Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec); |
| 93 | + } |
| 94 | + float thread_m2 = 0.0f; |
| 95 | + float warp_m2 = 0.0f; |
| 96 | + |
| 97 | +#pragma unroll |
| 98 | + for (int i = 0; i < HalfVecSize; i++) { |
| 99 | + // rope: rotate adjacent (even, odd) pairs with cos/sin; V passes through |
| 100 | + float input_left = static_cast<float>(src_vec[2 * i]); |
| 101 | + float input_right = static_cast<float>(src_vec[2 * i + 1]); |
| 102 | + |
| 103 | + if (hi < num_heads + kv_num_heads) { |
| 104 | + const float cos_tmp = cos_emb_vec[i]; |
| 105 | + const float sin_tmp = sin_emb_vec[i]; |
| 106 | + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; |
| 107 | + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; |
| 108 | + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; |
| 109 | + out_vec[2 * i] = |
| 110 | + static_cast<T>(tmp1); |
| 111 | + out_vec[2 * i + 1] = |
| 112 | + static_cast<T>(tmp2); |
| 113 | + } else { |
| 114 | + out_vec[2 * i] = src_vec[2 * i]; |
| 115 | + out_vec[2 * i + 1] = src_vec[2 * i + 1]; |
| 116 | + } |
| 117 | + } |
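| | + // RMSNorm for Q/K: thread_m2 holds this lane's partial sum of squares of the |
| | + // rotated values; the warp-wide reduce gives the full head's sum on every lane. |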
| 118 | + if (hi < (num_heads + kv_num_heads)) { // q k |
| 119 | + WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); |
| 120 | + float row_variance = |
| 121 | + max(warp_m2 / head_size, 0.0f); |
| 122 | + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); |
| 123 | + LoadT q_norm_vec, k_norm_vec; |
| 124 | + if (hi < num_heads) { // q |
| 125 | + Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec); |
| 126 | + #pragma unroll |
| 127 | + for (int i = 0; i < VecSize; i++) { |
| 128 | + out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i])); |
| 129 | + } |
| 130 | + } else { // k |
| 131 | + Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec); |
| 132 | + for (int i = 0; i < VecSize; i++) { |
| 133 | + out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i])); |
| 134 | + } |
| 135 | + } |
| 136 | + } |
| 137 | + if (hi < num_heads) { |
| 138 | + // write q |
| 139 | + Store<T, VecSize>(out_vec, &qkv_out[ori_idx]); |
| 140 | + } else { |
| 141 | + // write k/v into the paged cache block |
| 142 | + const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; |
| 143 | + const uint32_t tgt_idx = |
| 144 | + block_idx * kv_num_heads * block_size * head_size + |
| 145 | + kv_head_idx * block_size * head_size + block_offset * head_size + |
| 146 | + h_bias; |
| 147 | + if (hi < num_heads + kv_num_heads) { |
| 148 | + Store<T, VecSize>(out_vec, &key_cache[tgt_idx]); |
| 149 | + } else { |
| 150 | + Store<T, VecSize>(out_vec, &value_cache[tgt_idx]); |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + } |
| 155 | +} |
| 156 | + |
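For readers tracing the math, here is a rough host-side reference of what one Q/K head goes through in the kernel above (RoPE on adjacent element pairs, then RMSNorm scaled by the per-channel norm weight). The function name, float types, and std::vector layout are illustrative only and not part of this PR.

```cpp
// Illustrative reference only: per-head RoPE + RMSNorm as done in
// append_decode_cache_T_rope_qk_norm_kernel, in plain C++ for one head.
#include <cmath>
#include <vector>

void rope_then_rms_norm_ref(std::vector<float>& head,          // [head_size]
                            const std::vector<float>& cos_emb,  // [head_size / 2]
                            const std::vector<float>& sin_emb,  // [head_size / 2]
                            const std::vector<float>& norm_w,   // [head_size]
                            float eps) {
  const size_t half = head.size() / 2;
  float sum_sq = 0.0f;
  for (size_t i = 0; i < half; ++i) {
    // Rotate the (even, odd) pair, mirroring the kernel's inner loop.
    const float x0 = head[2 * i];
    const float x1 = head[2 * i + 1];
    head[2 * i] = x0 * cos_emb[i] - x1 * sin_emb[i];
    head[2 * i + 1] = x1 * cos_emb[i] + x0 * sin_emb[i];
    sum_sq += head[2 * i] * head[2 * i] + head[2 * i + 1] * head[2 * i + 1];
  }
  // RMSNorm over the whole head, then per-channel scale by the norm weight.
  const float inv_rms = 1.0f / std::sqrt(sum_sq / head.size() + eps);
  for (size_t i = 0; i < head.size(); ++i) {
    head[i] *= inv_rms * norm_w[i];
  }
}
```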
21 | 157 | template <typename T, int VecSize = 1>
|
22 | 158 | __global__ void append_decode_cache_T_rope_kernel(
|
23 | 159 | const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
|