diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 4fb5c93d03..cb988baa5c 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -609,6 +609,273 @@ __global__ void append_speculate_cache_neox_rope_kernel( } } +template +__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( + const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 * + // gqa_group_size, head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + T* __restrict__ cache_k_scale, + T* __restrict__ cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size, + const bool rope_3d, + const float rms_norm_eps) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int token_id = blockIdx.x; + + const int bid = batch_id_per_token[token_id]; + + const int start_token_idx = cu_seqlens_q[bid]; + const int head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid] + token_id - start_token_idx; + if (write_seq_id == 0) return; + const int* block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + int cache_offset; + if (head_idx < num_heads) { + cache_offset = 0; + } else if (head_idx < num_heads + 2 * gqa_group_size) { + cache_offset = block_idx * gqa_group_size * block_size + (head_idx - num_heads) % gqa_group_size * block_size + block_offset; + } + T *cache_k_scale_now = cache_k_scale + cache_offset; + T *cache_v_scale_now = cache_v_scale + cache_offset; + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT bias_vec; + LoadOutScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + token_id * hidden_size; + T* qkv_out_now = qkv_out + token_id * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec[2 * i] = + static_cast(tmp1); + bias_vec[2 * i + 1] = + static_cast(tmp2); + } + // qk norm + if (q_norm_weight) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadOutScaleT q_norm_vec; + Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(static_cast(bias_vec[i]) * row_inv_var * q_norm_vec[i]); + } + } + Store(bias_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + LoadOutScaleT out_scale_vec1, out_scale_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + token_id * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale = T(1.0f); + const int k_head_idx = head_idx - num_heads; + const int v_head_idx = head_idx - num_heads - gqa_group_size; + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec1[0] = + static_cast(tmp1); + bias_vec1[1] = + static_cast(tmp2); + } else { + bias_vec1[0] = static_cast(input_left); + bias_vec1[1] = static_cast(input_right); + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec2[0] = + static_cast(tmp1); + bias_vec2[1] = + static_cast(tmp2); + } else { + bias_vec2[0] = static_cast(input_left); + bias_vec2[1] = static_cast(input_right); + } + if (k_norm_weight) { + if (head_idx < num_heads + gqa_group_size) { + LoadOutScaleT k_norm_vec1, k_norm_vec2; + Load(&k_norm_weight[head_bias], &k_norm_vec1); + Load(&k_norm_weight[head_bias + 8], &k_norm_vec2); + // qk norm + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + bias_vec1[i] = static_cast(static_cast(bias_vec1[i]) * row_inv_var * k_norm_vec1[i]); + bias_vec2[i] = static_cast(static_cast(bias_vec2[i]) * row_inv_var * k_norm_vec2[i]); + } + } + } + // reduce max, 1 head per warp + T local_max = -INFINITY; +#pragma unroll + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + local_max = __hmax(local_max, __habs(bias_vec1[i])); + local_max = __hmax(local_max, __habs(bias_vec2[i])); + } +#pragma unroll + for (int m_offset = 16; m_offset > 0; m_offset /= 2) { + local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + } + + scale = __hdiv(448, local_max); + + if (lane_id == 0) { + if (head_idx < num_heads + gqa_group_size) { + cache_k_scale_now[0] = __hdiv(1, scale); + } else { + cache_v_scale_now[0] = __hdiv(1, scale); + } + } + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + cache_vec[i] = QuantToC8(scale, bias_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, bias_vec2[i], max_bound, min_bound); + } + if (head_idx < num_heads + gqa_group_size) { + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + template +void append_speculate_cache_fp8_dynamic_rope(const T* qkv, + uint8_t* key_cache, + uint8_t* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + T* cache_k_scale, + T* cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const int token_num, + const cudaStream_t& stream, + const bool rope_3d, + const float rms_norm_eps) { + constexpr int num_warps = 4; + const int all_warps = + ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; + dim3 grids(token_num, all_warps / num_warps); + + append_clear_cache_int8_block<4> + <<>>(key_cache, + value_cache, + seq_lens, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + kv_num_heads); + append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel + <<>>(qkv, + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + q_norm_weight, + k_norm_weight, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); +} + template void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* key_cache, @@ -459,6 +531,35 @@ void SpeculateWriteCacheWithRoPEKernel( reinterpret_cast(k_norm_weight.get().data()), rms_norm_eps, rope_3d); + } else if (cache_quant_type_str == "block_wise_fp8") { + append_speculate_cache_fp8_dynamic_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast(cache_k_scale.get().data())), + const_cast(reinterpret_cast(cache_v_scale.get().data())), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + rope_3d, + rms_norm_eps + ); } else { PD_THROW( "append_decode_cache_rope_qk_norm not support cachekv quant yet"); @@ -561,6 +662,35 @@ void SpeculateWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); + } else if (cache_quant_type_str == "block_wise_fp8") { + append_speculate_cache_fp8_dynamic_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast(cache_k_scale.get().data())), + const_cast(reinterpret_cast(cache_v_scale.get().data())), + nullptr, // q_norm_weight + nullptr, // k_norm_weight + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + rope_3d, + rms_norm_eps + ); } else if (cache_quant_type_str == "cache_int4_zp") { append_speculate_cache_int4_rope( reinterpret_cast(qkv_ptr),