diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu
index 5e4ce35da7..8a760d53c1 100644
--- a/custom_ops/gpu_ops/append_attention.cu
+++ b/custom_ops/gpu_ops/append_attention.cu
@@ -59,7 +59,6 @@ void AppendAttentionKernel(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
   int max_dec_len_this_time = set_max_lengths.data<int>()[2];
   int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
   int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
+  int max_kv_len_this_time = set_max_lengths.data<int>()[8];
 
   auto main_stream = qkv.stream();
   static cudaEvent_t main_event;
@@ -245,7 +245,6 @@ void AppendAttentionKernel(
 
   if (max_just_dec_len_this_time > 0) {
     int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
-    int max_len_kv_data = max_len_kv.data<int>()[0];
 
     cudaStream_t exec_stream;
     if (max_enc_len_this_time > 0) {
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
        case paddle::DataType::INT8:{
          int8_t tmp;
          dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-             decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+             decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
          break;
        }
        case paddle::DataType::FLOAT8_E4M3FN:{
          phi::dtype::float8_e4m3fn tmp;
          dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-             decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+             decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
          break;
        }
      }
    } else {
      data_t tmp;
      dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-         decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+         decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
    }
     if (max_enc_len_this_time > 0) {
       cudaEventRecord(decoder_event, exec_stream);
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
     const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
       decoder_tile_ids_per_batch,
       decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
       decoder_tile_ids_per_batch,
       decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
     const std::vector<int64_t>& decoder_num_blocks_shape,
     const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
     const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
    const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::DataType& decoder_tile_ids_per_batch_dtype,
     const paddle::DataType& decoder_num_blocks_dtype,
     const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
     const paddle::optional<paddle::DataType>& rotary_embs_dtype,
     const paddle::optional<paddle::DataType>& attn_mask_dtype,
     const paddle::optional<paddle::DataType>& qkv_bias_dtype,
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
     const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
     const std::vector<int64_t>& decoder_num_blocks_shape,
     const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
     const std::vector<int64_t>& fmha_out_shape,
     const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
     const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
     const paddle::DataType& decoder_tile_ids_per_batch_dtype,
     const paddle::DataType& decoder_num_blocks_dtype,
     const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
     const paddle::DataType& fmha_out_dtype,
     const paddle::optional<paddle::DataType>& rotary_embs_dtype,
     const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
        "decoder_tile_ids_per_batch",
        "decoder_num_blocks",
        "set_max_lengths",
-       "max_len_kv",
        paddle::Optional("rotary_embs"),
        paddle::Optional("attn_mask"),
        paddle::Optional("qkv_bias"),
@@ -1107,7 +1097,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
        "decoder_tile_ids_per_batch",
        "decoder_num_blocks",
        "set_max_lengths",
-       "max_len_kv",
        "fmha_out",
        paddle::Optional("rotary_embs"),
        paddle::Optional("attn_mask"),
diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index 2e2e8c7bab..60a7f0e9e0 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -19,7 +19,7 @@
 template <int THREADBLOCK_SIZE>
 __global__ void
-GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
+GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
                 const int *seq_lens_encoder,
                 const int *seq_lens_this_time_merged,
                 const int *seq_lens_encoder_merged, const int *seq_mapping,
@@ -37,41 +37,28 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
   int max_just_dec_merged_len_this_time_this_thread = 0;
   int max_system_len_this_thread = 0;
   int max_dec_len_without_system_this_thread = 0;
+  int max_len_kv_this_thread = 0;
   for (int i = tid; i < batch_size; i += blockDim.x) {
     const int seq_len_this_time = seq_lens_this_time[i];
+    const int seq_len_decoder = seq_lens_decoder[i];
     max_len_this_time_this_thread =
         max(seq_len_this_time, max_len_this_time_this_thread);
     max_len_encoder_this_thread =
         max(seq_lens_encoder[i], max_len_encoder_this_thread);
-    max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
+    max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
+
     if (seq_len_this_time <= 0)
       continue;
-    const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
+    const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
     max_len_this_thread =
-        max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
+        max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
     max_just_dec_len_this_thread =
         max(max_just_dec_len_this_thread, max_just_dec_len_now);
-    if (system_lens) {
-      const int real_bid = seq_mapping[i];
-      const int system_len_now = system_lens[real_bid];
-      max_system_len_this_thread =
-          max(max_system_len_this_thread, system_len_now);
-      max_dec_len_without_system_this_thread =
-          max(max_dec_len_without_system_this_thread,
-              max_just_dec_len_now - system_len_now);
-    }
-  }
-  if (system_lens) {
-    for (int i = tid; i < batch_size; i += blockDim.x) {
-      const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
-      if (ori_seq_len_this_time <= 0)
-        continue;
-      const int max_just_dec_merged_len_this_time_now =
-          seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
-      max_just_dec_merged_len_this_time_this_thread =
-          max(max_just_dec_merged_len_this_time_this_thread,
-              max_just_dec_merged_len_this_time_now);
-    }
+
+    if (seq_len_decoder == 0)
+      continue;
+    max_len_kv_this_thread =
+        max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
   }
   int total_max_len_this_time =
       BlockReduce(temp_storage)
@@ -86,23 +73,18 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
       BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
   int total_just_dec = BlockReduce(temp_storage)
                            .Reduce(max_just_dec_len_this_thread, MaxOp<int>());
-  int total_just_dec_merged =
-      BlockReduce(temp_storage)
-          .Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
-  int total_system_len = BlockReduce(temp_storage)
-                             .Reduce(max_system_len_this_thread, MaxOp<int>());
-  int total_dec_len_without_system =
-      BlockReduce(temp_storage)
-          .Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
+  int total_max_len_kv =
+      BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
   if (tid == 0) {
     max_lens[0] = total_max_len_this_time;
     max_lens[1] = total_max_len_encoder;
     max_lens[2] = total_max_len_decoder;
     max_lens[3] = total;
     max_lens[4] = total_just_dec;
-    max_lens[5] = total_just_dec_merged;
-    max_lens[6] = total_system_len;
-    max_lens[7] = total_dec_len_without_system;
+    max_lens[5] = max_just_dec_merged_len_this_time_this_thread;
+    max_lens[6] = max_system_len_this_thread;
+    max_lens[7] = max_dec_len_without_system_this_thread;
+    max_lens[8] = total_max_len_kv;
   }
 }
 
@@ -208,25 +190,68 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
                               const int *__restrict__ seq_lens_encoder,
                               int *__restrict__ batch_ids,
                               int *__restrict__ tile_ids_per_batch,
-                              int *__restrict__ num_blocks_x, const int bsz,
+                              int *__restrict__ num_blocks_x,
+                              const int bsz,
                               const int num_rows_per_block,
                               const int group_size) {
-  if (threadIdx.x == 0) {
-    int gridx = 0;
-    int index = 0;
-    for (uint32_t bid = 0; bid < bsz; bid++) {
+  // one block, one warp
+  const int lane = threadIdx.x % warpSize;
+
+  __shared__ int global_offset;
+  if (threadIdx.x == 0) global_offset = 0;
+  __syncthreads();
+
+  // loop over warp tiles: [base, base + 32)
+  for (int base = 0; base < bsz; base += warpSize) {
+    const int bid = base + lane;
+    const bool active = (bid < bsz);
+
+    // compute loop_times for this bid
+    int loop_times = 0;
+    if (active) {
       int seq_len = seq_lens_q[bid];
       if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
         seq_len = 0;
       }
-      const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
-      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
-        batch_ids[index] = bid;
-        tile_ids_per_batch[index++] = tile_id;
+      loop_times = div_up(seq_len * group_size, num_rows_per_block);
+    }
+
+    // prefix sum across lanes: each lane's start offset within this warp tile
+    unsigned mask = __ballot_sync(0xffffffff, active);
+    // inclusive scan
+    int x = loop_times;
+    for (int offset = 1; offset < warpSize; offset <<= 1) {
+      int y = __shfl_up_sync(mask, x, offset);
+      if (lane >= offset) x += y;
+    }
+    int excl = x - loop_times;                           // exclusive prefix
+    int tile_sum = __reduce_add_sync(mask, loop_times);  // warp tile sum
+
+    // write batch_ids and tile_ids_per_batch
+    int base_offset;
+    if (lane == 0) {
+      base_offset = global_offset;
+    }
+    base_offset = __shfl_sync(mask, base_offset, 0);
+    if (active && loop_times > 0) {
+      int write_base = base_offset + excl;
+      // fill [write_base, write_base + loop_times)
+      for (int t = 0; t < loop_times; ++t) {
+        int pos = write_base + t;
+        batch_ids[pos] = bid;
+        tile_ids_per_batch[pos] = t;
       }
-      gridx += loop_times;
     }
-    *num_blocks_x = gridx;
+
+    // advance the running offset for the next warp tile
+    if (lane == 0) {
+      global_offset += tile_sum;
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x == 0) {
+    *num_blocks_x = global_offset;
   }
 }
 
@@ -256,29 +281,6 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
   }
 }
 
-template <int THREADBLOCK_SIZE>
-__global__ void
-get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
-                     const int *seq_lens_decoder, const int batch_size) {
-  const int tid = threadIdx.x;
-
-  typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-
-  int max_len_this_thread = 0;
-  for (int i = tid; i < batch_size; i += blockDim.x) {
-    if (seq_lens_decoder[i] == 0)
-      continue;
-    max_len_this_thread =
-        max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
-  }
-  int total =
-      BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
-  if (tid == 0) {
-    *max_seq_lens_out = total;
-  }
-}
-
 void GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
@@ -295,7 +297,6 @@ void GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &kv_batch_ids,           // Inplace
     paddle::Tensor &kv_tile_ids_per_batch,  // Inplace
     paddle::Tensor &kv_num_blocks_x_cpu,    // Inplace, CPU
-    paddle::Tensor &max_len_kv_cpu,         // Inplace, CPU
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
@@ -319,15 +320,7 @@ void GetBlockShapeAndSplitKVBlock(
   int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
   int max_system_len = max_len_cpu_ptr[6];
   int max_just_dec_len_without_system = max_len_cpu_ptr[7];
-
-  auto max_len_kv =
-      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
-  get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
-      max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
-      seq_lens_decoder.data<int>(), bsz);
-
-
-  max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
+  int max_kv_len_this_time = max_len_cpu_ptr[8];
 
   // decoder
   if (max_dec_len_this_time > 0) {
@@ -416,12 +409,8 @@ void GetBlockShapeAndSplitKVBlock(
       decoder_num_blocks_cpu.copy_(
           decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-          decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream));
     }
   } else {
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream));
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
         decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream));
     decoder_num_blocks_cpu.copy_(
@@ -479,7 +468,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
        "kv_batch_ids",
        "kv_tile_ids_per_batch",
        "kv_num_blocks_x_cpu",
-       "max_len_kv_cpu"
     })
     .Outputs({
diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 85f88cf123..f285016992 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -64,7 +64,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor &decoder_batch_ids,
     const paddle::Tensor &decoder_tile_ids_per_batch,
     const paddle::Tensor &decoder_num_blocks_cpu,
-    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    const paddle::Tensor &set_max_lengths,
     const paddle::optional<paddle::Tensor> &rotary_embs,
     const paddle::optional<paddle::Tensor> &attn_mask,
     const paddle::optional<paddle::Tensor> &qkv_bias,
@@ -106,7 +106,7 @@ void AppendAttentionWithOutput(
     const paddle::Tensor &decoder_batch_ids,
     const paddle::Tensor &decoder_tile_ids_per_batch,
     const paddle::Tensor &decoder_num_blocks_cpu,
-    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    const paddle::Tensor &set_max_lengths,
     paddle::Tensor &fmha_out,
     const paddle::optional<paddle::Tensor> &rotary_embs,
     const paddle::optional<paddle::Tensor> &attn_mask,
@@ -315,7 +315,6 @@ void GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &kv_batch_ids,           // Inplace
     paddle::Tensor &kv_tile_ids_per_batch,  // Inplace
     paddle::Tensor &kv_num_blocks_x_cpu,    // Inplace, Pinned Memory
-    paddle::Tensor &max_len_kv_cpu,         // Inplace, Pinned Memory
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index 10608676c0..a4351dcfd0 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -119,8 +119,6 @@ class ForwardMeta:
     kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
     # The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
     kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
-    # The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
-    max_len_kv_cpu: Optional[paddle.Tensor] = None
 
     decoder_chunk_size_device: Optional[paddle.Tensor] = None
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index d42c4b80c5..b412e63247 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -150,7 +150,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -291,7 +290,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                forward_meta.max_len_kv_cpu,
                 res,
                 metadata.rotary_embs,
                 metadata.attn_mask,
@@ -347,7 +345,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 metadata.attn_mask,
                 layer.qkv_bias,
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 15750d090d..be7abe6572 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -207,7 +207,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -340,7 +339,6 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,  # from buffer
                 forward_meta.decoder_num_blocks_cpu,
                 metadata.max_len_tensor_cpu_decoder,
-                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 forward_meta.attn_mask,
                 layer.qkv_bias,
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 8967429624..0fc5489bb6 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -83,6 +83,7 @@ class MLAAttentionMetadata(AttentionMetadata):
 
     max_enc_len_this_time: Optional[paddle.Tensor] = None
     max_dec_len_this_time: Optional[paddle.Tensor] = None
+    max_kv_len_this_time: Optional[paddle.Tensor] = None
 
 
 class MLAAttentionBackend(AttentionBackend):
@@ -199,7 +200,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -210,6 +210,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -362,7 +363,7 @@ def forward_decode(
                 forward_meta.decoder_num_blocks_device,
                 forward_meta.decoder_chunk_size_device,
                 metadata.max_dec_len_this_time,
-                forward_meta.max_len_kv_cpu,
+                metadata.max_kv_len_this_time,
                 None,  # attn_mask
                None,  # qkv_bias
                None,  # qkv_out_scales
@@ -478,7 +479,7 @@ def forward_mixed(
                 forward_meta.decoder_num_blocks_device,
                 forward_meta.decoder_chunk_size_device,
                 metadata.max_dec_len_this_time,
-                forward_meta.max_len_kv_cpu,
+                metadata.max_kv_len_this_time,
                 None,  # attn_mask
                 None,  # qkv_bias
                 None,  # qkv_out_scales
diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py
index 7cf9636876..6216d0cd10 100644
--- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py
+++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py
@@ -49,7 +49,6 @@ def append_attention(
     decoder_tile_ids_per_batch: paddle.Tensor,
     decoder_num_blocks: paddle.Tensor,
     set_max_lengths: paddle.Tensor,
-    max_len_kv: paddle.Tensor,
     rotary_embs: Optional[paddle.Tensor] = None,
     attn_mask: Optional[paddle.Tensor] = None,
     qkv_bias: Optional[paddle.Tensor] = None,
@@ -107,7 +106,6 @@ def append_attention(
         decoder_tile_ids_per_batch,
         decoder_num_blocks,
         set_max_lengths,
-        max_len_kv,
         rotary_embs,
         attn_mask,
         qkv_bias,
@@ -169,7 +167,6 @@ def append_attention_with_output(
     decoder_tile_ids_per_batch: paddle.Tensor,
     decoder_num_blocks: paddle.Tensor,
     set_max_lengths: paddle.Tensor,
-    max_len_kv: paddle.Tensor,
     out: paddle.tensor,  # attention output
     rotary_embs: Optional[paddle.Tensor] = None,
     attn_mask: Optional[paddle.Tensor] = None,
@@ -228,7 +225,6 @@ def append_attention_with_output(
         decoder_tile_ids_per_batch,
         decoder_num_blocks,
         set_max_lengths,
-        max_len_kv,
         out,
         rotary_embs,
         attn_mask,
diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
index edcf8a692f..1cd5f4f142 100644
--- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
+++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
@@ -40,7 +40,6 @@ def get_block_shape_and_split_kv_block(
     kv_batch_ids: paddle.Tensor,
     kv_tile_ids_per_batch: paddle.Tensor,
     kv_num_blocks_x_cpu: paddle.Tensor,
-    max_len_kv_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -67,7 +66,6 @@ def get_block_shape_and_split_kv_block(
         kv_batch_ids,
         kv_tile_ids_per_batch,
         kv_num_blocks_x_cpu,
-        max_len_kv_cpu,
         encoder_block_shape_q,
         decoder_block_shape_q,
         group_size,
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 92eac7ddab..0a877d137a 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -872,7 +872,6 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["kv_batch_ids"] = None
         self.share_inputs["kv_tile_ids_per_batch"] = None
         self.share_inputs["kv_num_blocks_x_cpu"] = None  # CPU
-        self.share_inputs["max_len_kv_cpu"] = None  # CPU
 
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -1118,7 +1117,6 @@ def initialize_forward_meta(self):
             kv_batch_ids=self.share_inputs["kv_batch_ids"],
             kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
             kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
-            max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
         )
 
         # Update Batch type for cuda graph for only_decode_batch
@@ -1251,7 +1249,7 @@ def initialize_attn_backend(self) -> None:
         # adapted to cudagraph.
self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") - self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") @@ -1260,7 +1258,6 @@ def initialize_attn_backend(self) -> None: self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() # Get the attention backend attn_cls = get_attention_backend() diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 6da6681e7a..5a9ac5afcd 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -386,7 +386,7 @@ def init_tensor(self): self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") - self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + self.max_len_tensor_cpu = paddle.full([9], 0, dtype="int32").cpu() self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") @@ -394,7 +394,6 @@ def init_tensor(self): self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() self.cache_shape = ( self.max_block_num, @@ -495,7 +494,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.kv_batch_ids, self.kv_tile_ids_per_batch, self.kv_num_blocks_x_cpu, - self.max_len_kv_cpu, 64, 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, @@ -529,7 +527,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, self.rope_emb, # rope_emb None, # attn_mask None, # qkv_bias @@ -591,7 +588,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, self.rope_emb, # rope_emb None, # attn_mask None, # qkv_bias diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index c198d1291d..1256a299ab 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -384,15 +384,13 @@ def init_tensor(self): self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") - self.max_len_tensor_cpu = paddle.full([8], 0, 
dtype="int32").cpu() + self.max_len_tensor_cpu = paddle.full([9], 0, dtype="int32").cpu() self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() - self.cache_shape = ( self.max_block_num, self.kv_num_head, @@ -476,7 +474,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.kv_batch_ids, self.kv_tile_ids_per_batch, self.kv_num_blocks_x_cpu, - self.max_len_kv_cpu, 64, 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, @@ -512,7 +509,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, self.max_len_tensor_cpu, - self.max_len_kv_cpu, out, self.rope_emb, # rope_emb None, # attn_mask diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index a6bb8bd46f..795c2354e8 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -204,14 +204,13 @@ def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") - max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([9], 0, dtype="int32").cpu() encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() - max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu() q_norm_weight = np.ones([self.head_dim]) k_norm_weight = np.ones([self.head_dim]) self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32") @@ -233,7 +232,6 @@ def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, - max_len_kv_cpu, encoder_block_shape_q, decoder_block_shape_q, self.num_q_head // self.num_kv_head, @@ -264,7 +262,6 @@ def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, decoder_tile_ids_per_batch, decoder_num_blocks, max_len_tensor_cpu, - max_len_kv_cpu, rotary_embs, attn_mask, None, # qkv_bias