diff --git a/custom_ops/cpu_ops/get_padding_offset.cc b/custom_ops/cpu_ops/get_padding_offset.cc
index 8fe73bc8e4..02ee71a263 100644
--- a/custom_ops/cpu_ops/get_padding_offset.cc
+++ b/custom_ops/cpu_ops/get_padding_offset.cc
@@ -84,7 +84,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
                        seq_length,
                        bsz);
   return {x_remove_padding,
-          cum_offsets_out,
           padding_offset,
           cu_seqlens_q,
           cu_seqlens_k};
@@ -97,7 +96,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
     const std::vector<int64_t> &seq_len_shape) {
   int64_t bsz = seq_len_shape[0];
   int64_t seq_len = input_ids_shape[1];
-  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -106,7 +105,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
     const paddle::DataType &token_num_dtype,
     const paddle::DataType &seq_len_dtype) {
   return {input_ids_dtype,
-          seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype};
@@ -115,7 +113,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
 PD_BUILD_STATIC_OP(get_padding_offset_cpu)
     .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
     .Outputs({"x_remove_padding",
-              "cum_offsets_out",
               "padding_offset",
               "cu_seqlens_q",
               "cu_seqlens_k"})
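Note on the substitution made throughout this patch: in the padded layout, sequence bi begins at token bi * max_input_length, and cum_offsets[bi] counts the padding tokens accumulated before it, so its packed start is bi * max_input_length - cum_offsets[bi]. Since cu_seqlens_q[bi] is exactly the number of real tokens before sequence bi, the two indices coincide, which is why cum_offsets can be dropped. A minimal NumPy sketch of the identity (illustrative values, not from this patch):

import numpy as np

# Three sequences padded to max_input_length = 8.
seq_lens = np.array([3, 8, 5])
max_input_length = 8

# cum_offsets[bi]: padding tokens before sequence bi (exclusive prefix sum).
pads = max_input_length - seq_lens
cum_offsets = np.concatenate([[0], np.cumsum(pads)[:-1]])

# cu_seqlens_q[bi]: real tokens before sequence bi (the tensor this patch keeps).
cu_seqlens_q = np.concatenate([[0], np.cumsum(seq_lens)])

for bi in range(len(seq_lens)):
    assert bi * max_input_length - cum_offsets[bi] == cu_seqlens_q[bi]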
diff --git a/custom_ops/cpu_ops/rebuild_padding.cc b/custom_ops/cpu_ops/rebuild_padding.cc
index 8ce533d041..adbf95e5fd 100644
--- a/custom_ops/cpu_ops/rebuild_padding.cc
+++ b/custom_ops/cpu_ops/rebuild_padding.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
 template <typename T>
 void RebuildPaddingCPUImpl(T *output_data,
                            const T *input_data,
-                           const int *cum_offsets_data,
+                           const int *cu_seqlens_q_data,
                            const int *seq_len_this_time_data,
                            const int *seq_lens_decoder_data,
                            const int *seq_lens_encoder_data,
@@ -40,11 +40,12 @@ void RebuildPaddingCPUImpl(T *output_data,
     if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
       continue;
     }
+
     if (seq_lens_encoder_data[bi] > 0) {
       seq_id = seq_lens_encoder_data[bi] - 1;
     }
-    const int ori_token_idx =
-        bi * max_input_length - cum_offsets_data[bi] + seq_id;
+
+    const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
     const int src_offset = ori_token_idx * dim_embed + bias_idx;
 
     output_data[i] = input_data[src_offset];
@@ -54,7 +55,7 @@
 template <typename T>
 void RebuildAppendPaddingCPUImpl(T *output_data,
                                  const T *input_data,
-                                 const int *cum_offsets_data,
+                                 const int *cu_seqlens_q_data,
                                  const int *seq_len_this_time_data,
                                  const int *seq_lens_decoder_data,
                                  const int *seq_lens_encoder_data,
@@ -69,30 +70,32 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
     int bi = ori_token_id / max_input_length;
     if (seq_len_this_time_data[bi] == 0 ||
         (seq_lens_decoder_data[bi] == 0 &&
-        seq_lens_encoder_data[bi] == 0)) {
-      continue;
-    }
+         seq_lens_encoder_data[bi] == 0)) {
+      continue;
+    }
     int seq_id = 0;
+
     if (seq_lens_encoder_data[bi] > 0) {
       seq_id = seq_lens_encoder_data[bi] - 1;
     }
-    int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
+    int input_token_id = cu_seqlens_q_data[bi] + seq_id;
     int bias_idx = i % dim_embed;
     int src_offset = input_token_id * dim_embed + bias_idx;
+
     output_data[i] = input_data[src_offset];
   }
 }
 
 std::vector<paddle::Tensor> RebuildPaddingCPU(
     const paddle::Tensor &tmp_out,
-    const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &cu_seqlens_q,
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
     int max_input_length) {
   auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
-  auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
+  auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
   auto seq_len_this_time_cpu =
       seq_len_this_time.copy_to(paddle::CPUPlace(), true);
   auto seq_lens_decoder_cpu =
@@ -107,7 +110,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
 
   int token_num = tmp_out_cpu.shape()[0];
   int dim_embed = tmp_out_cpu.shape()[1];
-  int bsz = cum_offsets_cpu.shape()[0];
+  int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
 
   paddle::Tensor out;
   if (output_padding_offset_cpu) {
@@ -128,7 +131,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
         {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
   }
 
-  const int *cum_offsets_data = cum_offsets_cpu.data<int>();
+  const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
   const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
   const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
   const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
@@ -141,7 +144,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
       case paddle::DataType::FLOAT32:
         RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
                                            tmp_out_cpu.data<float>(),
-                                           cum_offsets_data,
+                                           cu_seqlens_q_data,
                                            seq_len_this_time_data,
                                            seq_lens_decoder_data,
                                            seq_lens_encoder_data,
@@ -154,7 +157,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
         RebuildAppendPaddingCPUImpl<paddle::float16>(
            out.data<paddle::float16>(),
            tmp_out_cpu.data<paddle::float16>(),
-           cum_offsets_data,
+           cu_seqlens_q_data,
            seq_len_this_time_data,
            seq_lens_decoder_data,
            seq_lens_encoder_data,
@@ -167,7 +170,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
         RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
            out.data<paddle::bfloat16>(),
            tmp_out_cpu.data<paddle::bfloat16>(),
-           cum_offsets_data,
+           cu_seqlens_q_data,
            seq_len_this_time_data,
            seq_lens_decoder_data,
            seq_lens_encoder_data,
@@ -186,7 +189,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
       case paddle::DataType::FLOAT32:
         RebuildPaddingCPUImpl<float>(out.data<float>(),
                                      tmp_out_cpu.data<float>(),
-                                     cum_offsets_data,
+                                     cu_seqlens_q_data,
                                      seq_len_this_time_data,
                                      seq_lens_decoder_data,
                                      seq_lens_encoder_data,
@@ -198,7 +201,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
         RebuildPaddingCPUImpl<paddle::float16>(
            out.data<paddle::float16>(),
            tmp_out_cpu.data<paddle::float16>(),
-           cum_offsets_data,
+           cu_seqlens_q_data,
            seq_len_this_time_data,
            seq_lens_decoder_data,
            seq_lens_encoder_data,
@@ -207,11 +210,10 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
            elem_nums);
         break;
       case paddle::DataType::BFLOAT16:
-
         RebuildPaddingCPUImpl<paddle::bfloat16>(
            out.data<paddle::bfloat16>(),
            tmp_out_cpu.data<paddle::bfloat16>(),
-           cum_offsets_data,
+           cu_seqlens_q_data,
            seq_len_this_time_data,
            seq_lens_decoder_data,
            seq_lens_encoder_data,
@@ -230,7 +232,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
 
 std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
     const std::vector<int64_t> &tmp_out_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
+    const std::vector<int64_t> &cu_seqlens_q_shape,
     const std::vector<int64_t> &seq_len_this_time_shape,
     const std::vector<int64_t> &seq_lens_decoder_shape,
     const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -239,14 +241,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
   if (output_padding_offset_shape) {
     return {{-1, dim_embed}};
   } else {
-    int64_t bsz = cum_offsets_shape[0];
+    int64_t bsz = cu_seqlens_q_shape[0] - 1;
     return {{bsz, dim_embed}};
   }
 }
 
 std::vector<paddle::DataType> RebuildPaddingInferDtype(
     const paddle::DataType &tmp_out_dtype,
-    const paddle::DataType &cum_offsets_dtype,
+    const paddle::DataType &cu_seqlens_q_dtype,
     const paddle::DataType &seq_len_this_time_dtype,
     const paddle::DataType &seq_lens_decoder_dtype,
     const paddle::DataType &seq_lens_encoder_dtype,
@@ -256,7 +258,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
 
 PD_BUILD_STATIC_OP(rebuild_padding_cpu)
     .Inputs({"tmp_out",
-             "cum_offsets",
+             "cu_seqlens_q",
              "seq_len_this_time",
              "seq_lens_decoder",
              "seq_lens_encoder",
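Because cu_seqlens_q stores bsz + 1 prefix sums (leading zero included), the batch size that was previously read as cum_offsets.shape()[0] becomes cu_seqlens_q.shape()[0] - 1, both in RebuildPaddingCPU and in RebuildPaddingInferShape above. A one-line sanity check (illustrative):

import numpy as np

seq_lens = np.array([3, 8, 5])
cu_seqlens_q = np.concatenate([[0], np.cumsum(seq_lens)])  # shape [bsz + 1]
assert cu_seqlens_q.shape[0] - 1 == len(seq_lens)  # bsz for the {bsz, dim_embed} output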
diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index 8fae9b88c3..f505e1c326 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -101,7 +101,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
                                      cum_offsets_out.data<int>(),
                                      seq_length);
   return {x_remove_padding,
-          cum_offsets_out,
           batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
@@ -114,7 +113,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
     const std::vector<int64_t> &seq_len_shape) {
   int64_t bsz = seq_len_shape[0];
   int64_t seq_len = input_ids_shape[1];
-  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -123,7 +122,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
     const paddle::DataType &token_num_dtype,
     const paddle::DataType &seq_len_dtype) {
   return {input_ids_dtype,
-          seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype};
@@ -132,7 +130,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
 PD_BUILD_STATIC_OP(get_padding_offset)
     .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
     .Outputs({"x_remove_padding",
-              "cum_offsets_out",
               "batch_id_per_token",
               "cu_seqlens_q",
              "cu_seqlens_k"})
diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu
index 3d69e9e459..772fefa1ac 100644
--- a/custom_ops/gpu_ops/rebuild_padding.cu
+++ b/custom_ops/gpu_ops/rebuild_padding.cu
@@ -14,10 +14,11 @@
 
 #include "helper.h"  // NOLINT
 
+
 template <typename T, int VecSize>
 __global__ void RebuildPaddingKernel(T *output_data,
                                      const T *input_data,
-                                     const int *cum_offsets,
+                                     const int *cu_seqlens_q,
                                      const int *seq_len_this_time,
                                      const int *seq_len_decoder,
                                      const int *seq_len_encoder,
@@ -34,10 +35,10 @@ __global__ void RebuildPaddingKernel(T *output_data,
     int seq_id = 0;
     if (seq_len_this_time[bi] == 0) continue;
     if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
-    // if encoder, get last token; just decoder, get first token.
     if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
+
     const int ori_token_idx =
-        bi * max_input_length - cum_offsets[bi] + seq_id;
+        cu_seqlens_q[bi] + seq_id;
     const int src_offset = ori_token_idx * dim_embed + bias_idx;
     Load<T, VecSize>(&input_data[src_offset], &src_vec);
     Store<T, VecSize>(src_vec, &output_data[i]);
@@ -47,29 +48,31 @@ __global__ void RebuildPaddingKernel(T *output_data,
 
 template <typename T, int VecSize>
 __global__ void RebuildAppendPaddingKernel(T *output_data,
                                            const T *input_data,
-                                           const int *cum_offset,
+                                           const int *cu_seqlens_q,
                                            const int *seq_len_this_time,
                                            const int *seq_len_decoder,
                                            const int *seq_len_encoder,
                                            const int *output_padding_offset,
                                            const int max_input_length,
                                            const int dim_embed,
-                                           const int64_t output_elem_nums) {
+                                           const int64_t output_elem_nums,
+                                           const int bsz) {
   AlignedVector<T, VecSize> src_vec;
   const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (int64_t i = global_idx * VecSize; i < output_elem_nums;
        i += gridDim.x * blockDim.x * VecSize) {
     const int out_token_id = i / dim_embed;
-    const int ori_token_id =
-        out_token_id + output_padding_offset[out_token_id];
+    const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
+
     const int bi = ori_token_id / max_input_length;
+
     int seq_id = 0;
     if (seq_len_this_time[bi] == 0) continue;
     if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
-    // if encoder, get last token; just decoder, get first token.
-    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
-    const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
+    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
+    const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
+    const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
     const int bias_idx = i % dim_embed;
 
     Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
@@ -78,10 +81,11 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
   }
 }
 
+
 template <paddle::DataType D>
 std::vector<paddle::Tensor> rebuild_padding(
     const paddle::Tensor &tmp_out,            // [token_num, dim_embed]
-    const paddle::Tensor &cum_offsets,        // [bsz, 1]
+    const paddle::Tensor &cu_seqlens_q,       // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
@@ -100,7 +104,7 @@ std::vector<paddle::Tensor> rebuild_padding(
   std::vector<int64_t> tmp_out_shape = tmp_out.shape();
   const int token_num = tmp_out_shape[0];
   const int dim_embed = tmp_out_shape[1];
-  const int bsz = cum_offsets.shape()[0];
+  const int bsz = cu_seqlens_q.shape()[0] - 1;
 
   paddle::Tensor out;
   if (output_padding_offset) {
@@ -133,21 +137,22 @@ std::vector<paddle::Tensor> rebuild_padding(
         <<<grid_size, blocksize, 0, tmp_out.stream()>>>(
             reinterpret_cast<DataType_ *>(out.data<data_t>()),
             reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             seq_len_this_time.data<int>(),
             seq_lens_decoder.data<int>(),
             seq_lens_encoder.data<int>(),
             output_padding_offset.get_ptr()->data<int>(),
             max_input_length,
             dim_embed,
-            elem_nums);
+            elem_nums,
+            bsz);
   } else {
     RebuildPaddingKernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, tmp_out.stream()>>>(
             reinterpret_cast<DataType_ *>(out.data<data_t>()),
             reinterpret_cast<DataType_ *>(
                 const_cast<data_t *>(tmp_out.data<data_t>())),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             seq_len_this_time.data<int>(),
             seq_lens_decoder.data<int>(),
             seq_lens_encoder.data<int>(),
@@ -160,7 +165,7 @@ std::vector<paddle::Tensor> rebuild_padding(
 
 paddle::Tensor RebuildPaddingFunc(
     const paddle::Tensor &tmp_out,            // [token_num, dim_embed]
-    const paddle::Tensor &cum_offsets,        // [bsz, 1]
+    const paddle::Tensor &cu_seqlens_q,       // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
@@ -170,7 +175,7 @@ paddle::Tensor RebuildPaddingFunc(
     case paddle::DataType::BFLOAT16: {
       return rebuild_padding<paddle::DataType::BFLOAT16>(
           tmp_out,
-          cum_offsets,
+          cu_seqlens_q,
           seq_len_this_time,
           seq_lens_decoder,
           seq_lens_encoder,
@@ -180,7 +185,7 @@ paddle::Tensor RebuildPaddingFunc(
     case paddle::DataType::FLOAT16: {
      return rebuild_padding<paddle::DataType::FLOAT16>(
          tmp_out,
-          cum_offsets,
+          cu_seqlens_q,
          seq_len_this_time,
          seq_lens_decoder,
          seq_lens_encoder,
@@ -190,7 +195,7 @@ paddle::Tensor RebuildPaddingFunc(
     case paddle::DataType::FLOAT32: {
      return rebuild_padding<paddle::DataType::FLOAT32>(
          tmp_out,
-          cum_offsets,
+          cu_seqlens_q,
          seq_len_this_time,
          seq_lens_decoder,
          seq_lens_encoder,
@@ -208,14 +213,14 @@ paddle::Tensor RebuildPaddingFunc(
 
 std::vector<paddle::Tensor> RebuildPadding(
     const paddle::Tensor &tmp_out,            // [token_num, dim_embed]
-    const paddle::Tensor &cum_offsets,        // [bsz, 1]
+    const paddle::Tensor &cu_seqlens_q,       // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
     int max_input_length) {
   return {RebuildPaddingFunc(tmp_out,
-                             cum_offsets,
+                             cu_seqlens_q,
                              seq_len_this_time,
                              seq_lens_decoder,
                              seq_lens_encoder,
@@ -225,7 +230,7 @@ std::vector<paddle::Tensor> RebuildPadding(
 
 std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
     const std::vector<int64_t> &tmp_out_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
+    const std::vector<int64_t> &cu_seqlens_q_shape,
     const std::vector<int64_t> &seq_len_this_time_shape,
     const std::vector<int64_t> &seq_lens_decoder_shape,
     const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -235,14 +240,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
   if (output_padding_offset_shape) {
     return {{-1, dim_embed}};
   } else {
-    int64_t bsz = cum_offsets_shape[0];
+    int64_t bsz = cu_seqlens_q_shape[0] - 1;
     return {{bsz, dim_embed}};
   }
 }
 
 std::vector<paddle::DataType> RebuildPaddingInferDtype(
     const paddle::DataType &tmp_out_dtype,
-    const paddle::DataType &cum_offsets_dtype,
+    const paddle::DataType &cu_seqlens_q_dtype,
     const paddle::DataType &seq_len_this_time_dtype,
     const paddle::DataType &seq_lens_decoder_dtype,
     const paddle::DataType &seq_lens_encoder_dtype,
@@ -252,7 +257,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
 
 PD_BUILD_STATIC_OP(rebuild_padding)
     .Inputs({"tmp_out",
-             "cum_offsets",
+             "cu_seqlens_q",
              "seq_len_this_time",
              "seq_lens_decoder",
              "seq_lens_encoder",
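Reading aid for the two GPU kernels above: for each sequence, a prefill step copies the hidden state of the last prompt token (seq_id = seq_lens_encoder[bi] - 1), while a decode step copies its single new token (seq_id = 0); batches with no work this step are skipped. A NumPy reference of the non-speculative path under those assumptions (illustrative, not part of the patch):

import numpy as np

def rebuild_padding_reference(tmp_out, cu_seqlens_q, seq_len_this_time,
                              seq_lens_decoder, seq_lens_encoder):
    """Host-side mirror of RebuildPaddingKernel's indexing."""
    bsz = cu_seqlens_q.shape[0] - 1
    out = np.zeros((bsz, tmp_out.shape[1]), dtype=tmp_out.dtype)
    for bi in range(bsz):
        if seq_len_this_time[bi] == 0:
            continue
        if seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0:
            continue
        # Prefill: last prompt token; decode: the only new token.
        seq_id = seq_lens_encoder[bi] - 1 if seq_lens_encoder[bi] > 0 else 0
        out[bi] = tmp_out[cu_seqlens_q[bi] + seq_id]
    return out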
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
index 96186d761f..e37dacbf34 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
@@ -106,7 +106,6 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
                                      seq_length,
                                      max_draft_tokens);
   return {x_remove_padding,
-          cum_offsets_out,
           batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
@@ -121,7 +120,7 @@ std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
     const std::vector<int64_t>& seq_lens_encoder_shape) {
   int64_t bsz = seq_len_shape[0];
   int64_t seq_len = input_ids_shape[1];
-  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
@@ -132,7 +131,6 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
     const paddle::DataType& seq_len_dtype,
     const paddle::DataType& seq_lens_encoder_dtype) {
   return {input_ids_dtype,
-          seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype,
           seq_len_dtype};
@@ -141,12 +139,10 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
 PD_BUILD_STATIC_OP(speculate_get_padding_offset)
     .Inputs({"input_ids",
              "draft_tokens",
-             "cum_offsets",
              "token_num",
              "seq_len",
              "seq_lens_encoder"})
     .Outputs({"x_remove_padding",
-              "cum_offsets_out",
               "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
"seq_len", "seq_lens_encoder"}) .Outputs({"x_remove_padding", - "cum_offsets_out", "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"}) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0794e42cfa..e3dfcbf3c5 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -104,7 +104,6 @@ def pre_process( if speculative_decoding: ( ids_remove_padding, - cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, @@ -134,14 +133,12 @@ def pre_process( else: ( ids_remove_padding, - cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) return ( ids_remove_padding, - cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, @@ -501,7 +498,7 @@ def step_cuda( def rebuild_padding( tmp_out: paddle.Tensor, - cum_offsets: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, seq_len_this_time: paddle.Tensor, seq_lens_decoder: paddle.Tensor, seq_lens_encoder: paddle.Tensor, @@ -517,7 +514,7 @@ def rebuild_padding( hidden_states = rebuild_padding( tmp_out, - cum_offsets, + cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, @@ -529,7 +526,7 @@ def rebuild_padding( hidden_states = rebuild_padding( tmp_out, - cum_offsets, + cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, @@ -541,19 +538,19 @@ def rebuild_padding( hidden_states = rebuild_padding( tmp_out, - cum_offsets, + cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, output_padding_offset, max_input_length, ) - elif current_platform.is_gcu(): + # elif current_platform.is_gcu(): from fastdeploy.model_executor.ops.gcu import rebuild_padding hidden_states = rebuild_padding( tmp_out, - cum_offsets, + cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, @@ -565,7 +562,7 @@ def rebuild_padding( hidden_states = rebuild_padding_cpu( tmp_out, - cum_offsets, + cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 3033e41467..9b1c3335ec 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -274,7 +274,6 @@ def _init_model_inputs(self): self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"]) self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"]) - self.model_inputs["cum_offsets"] = paddle.clone(self.main_model_inputs["cum_offsets"]) self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"]) self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"]) self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"]) @@ -527,7 +526,6 @@ def _propose(self, target_hidden_states): # Remove padding ( ids_remove_padding, - cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, @@ -543,7 +541,6 @@ def _propose(self, target_hidden_states): ) # Initialize forward meta data self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) - self.model_inputs["cum_offsets"].copy_(cum_offsets, False) self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -578,7 +575,7 
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 3033e41467..9b1c3335ec 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -274,7 +274,6 @@ def _init_model_inputs(self):
         self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
         self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
         self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
-        self.model_inputs["cum_offsets"] = paddle.clone(self.main_model_inputs["cum_offsets"])
         self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
         self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
         self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
@@ -527,7 +526,6 @@ def _propose(self, target_hidden_states):
             # Remove padding
             (
                 ids_remove_padding,
-                cum_offsets,
                 batch_id_per_token,
                 cu_seqlens_q,
                 cu_seqlens_k,
@@ -543,7 +541,6 @@ def _propose(self, target_hidden_states):
             )
             # Initialize forward meta data
             self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-            self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
             self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
             self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
             self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -578,7 +575,7 @@ def _propose(self, target_hidden_states):
 
             hidden_states = rebuild_padding(
                 model_output,
-                self.model_inputs["cum_offsets"],
+                self.model_inputs["cu_seqlens_q"],
                 self.model_inputs["seq_lens_this_time"],
                 self.model_inputs["seq_lens_decoder"],
                 self.model_inputs["seq_lens_encoder"],
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index b52b35bc40..e572b90858 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -419,7 +419,7 @@ def _init_share_inputs(self, max_num_seqs: int):
             0,
             dtype="int64",
         )
-        self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -518,7 +518,6 @@ def _prepare_inputs(self) -> None:
         )
 
         self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-        self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
         self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -736,7 +735,7 @@ def _dummy_run(
 
             hidden_states = rebuild_padding(
                 model_output,
-                self.share_inputs["cum_offsets"],
+                self.share_inputs["cu_seqlens_q"],
                 self.share_inputs["seq_lens_this_time"],
                 self.share_inputs["seq_lens_decoder"],
                 self.share_inputs["seq_lens_encoder"],
@@ -961,7 +960,7 @@ class at the server level, which is too granular for ModelRunner.
 
             hidden_states = rebuild_padding(
                 model_output,
-                self.share_inputs["cum_offsets"],
+                self.share_inputs["cu_seqlens_q"],
                 self.share_inputs["seq_lens_this_time"],
                 self.share_inputs["seq_lens_decoder"],
                 self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_encoder"], @@ -1315,7 +1311,7 @@ class at the server level, which is too granular for ModelRunner. ) hidden_states = rebuild_padding( model_output, - self.share_inputs["cum_offsets"], + self.share_inputs["cu_seqlens_q"], self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_encoder"], @@ -1415,6 +1411,7 @@ class at the server level, which is too granular for ModelRunner. # 7. Updata 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: step_cuda( self.share_inputs, diff --git a/test/layers/test_rebuild_padding.py b/test/layers/test_rebuild_padding.py new file mode 100644 index 0000000000..a66bde915b --- /dev/null +++ b/test/layers/test_rebuild_padding.py @@ -0,0 +1,222 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import time +import unittest +from typing import Tuple + +import numpy as np +import paddle + + +class TestCuSeqlensQPerformance(unittest.TestCase): + + def setUp(self): + paddle.device.set_device("gpu:0") + + # Test configurations:(batch_size, max_seq_len, dim_embed, avg_seq_len_ratio) + self.test_configs = [ + # Small scale tests + (4, 512, 2048, 0.8), + (8, 512, 4096, 0.7), + # Medium scale tests + (16, 1024, 4096, 0.6), + (32, 1024, 4096, 0.8), + # Large scale tests + (64, 2048, 4096, 0.5), + (128, 1024, 8192, 0.7), + (256, 512, 4096, 0.9), + ] + + self.warmup_runs = 10 + self.benchmark_runs = 50 + + def generate_realistic_test_data( + self, batch_size: int, max_seq_len: int, dim_embed: int, avg_ratio: float + ) -> dict: + """Generate test data closer to real-world scenarios""" + + avg_seq_len = int(max_seq_len * avg_ratio) + std_seq_len = avg_seq_len // 4 + + seq_lens = np.random.normal(avg_seq_len, std_seq_len, batch_size) + seq_lens = np.clip(seq_lens, max_seq_len // 10, max_seq_len).astype(np.int32) + + total_tokens = np.sum(seq_lens) + + tmp_out = paddle.randn([total_tokens, dim_embed], dtype=paddle.float16) + tmp_out = tmp_out.cuda() + + cu_seqlens_q_np = np.zeros(batch_size + 1, dtype=np.int32) + for i in range(batch_size): + cu_seqlens_q_np[i + 1] = cu_seqlens_q_np[i] + seq_lens[i] + + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np, dtype=paddle.int32).cuda() + + seq_len_this_time = paddle.to_tensor(seq_lens, dtype=paddle.int32).cuda() + seq_len_decoder = paddle.to_tensor(seq_lens, dtype=paddle.int32).cuda() + seq_len_encoder = paddle.zeros([batch_size], dtype=paddle.int32).cuda() + + return { + "tmp_out": tmp_out, + "cu_seqlens_q": cu_seqlens_q, + "seq_len_this_time": seq_len_this_time, + "seq_len_decoder": seq_len_decoder, + "seq_len_encoder": seq_len_encoder, + "max_input_length": max_seq_len, + "actual_tokens": total_tokens, + "seq_lens": seq_lens, + } + + def 
+    def benchmark_cu_seqlens_performance(self, data_dict: dict) -> Tuple[float, float, paddle.Tensor]:
+        """Test performance of the cu_seqlens_q version"""
+
+        def rebuild_padding_cu_seqlens(
+            tmp_out, cu_seqlens_q, seq_len_this_time, seq_len_decoder, seq_len_encoder, max_input_length
+        ):
+            from fastdeploy.model_executor.pre_and_post_process import rebuild_padding
+
+            hidden_states = rebuild_padding(
+                tmp_out, cu_seqlens_q, seq_len_this_time, seq_len_decoder, seq_len_encoder, None, max_input_length
+            )
+            return hidden_states
+
+        for _ in range(self.warmup_runs):
+            result = rebuild_padding_cu_seqlens(
+                data_dict["tmp_out"],
+                data_dict["cu_seqlens_q"],
+                data_dict["seq_len_this_time"],
+                data_dict["seq_len_decoder"],
+                data_dict["seq_len_encoder"],
+                data_dict["max_input_length"],
+            )
+            paddle.device.cuda.synchronize()
+
+        paddle.device.cuda.synchronize()
+        start_time = time.perf_counter()
+
+        for _ in range(self.benchmark_runs):
+            result = rebuild_padding_cu_seqlens(
+                data_dict["tmp_out"],
+                data_dict["cu_seqlens_q"],
+                data_dict["seq_len_this_time"],
+                data_dict["seq_len_decoder"],
+                data_dict["seq_len_encoder"],
+                data_dict["max_input_length"],
+            )
+
+        paddle.device.cuda.synchronize()
+        end_time = time.perf_counter()
+
+        avg_time = (end_time - start_time) / self.benchmark_runs * 1000  # ms
+
+        # throughput (tokens/ms)
+        throughput = data_dict["actual_tokens"] / avg_time
+
+        return avg_time, throughput, result
+
+    def test_performance_scaling(self):
+        """Test performance under different scales"""
+        print("\n" + "=" * 90)
+        print("CU_SEQLENS_Q Performance Scaling Test")
+        print("=" * 90)
+        print(
+            f"{'Config':<20} {'Batch':<6} {'SeqLen':<7} {'Tokens':<8} {'Time(ms)':<10} {'Throughput':<12} {'Memory(MB)'}"
+        )
+        print("-" * 90)
+
+        results = []
+
+        for i, (batch_size, max_seq_len, dim_embed, avg_ratio) in enumerate(self.test_configs):
+            config_name = f"Config_{i+1}"
+
+            try:
+                data_dict = self.generate_realistic_test_data(batch_size, max_seq_len, dim_embed, avg_ratio)
+
+                paddle.device.cuda.empty_cache()
+                mem_before = paddle.device.cuda.memory_allocated() / 1024 / 1024  # MB
+
+                avg_time, throughput, result = self.benchmark_cu_seqlens_performance(data_dict)
+
+                mem_after = paddle.device.cuda.memory_allocated() / 1024 / 1024  # MB
+                mem_usage = mem_after - mem_before
+
+                results.append(
+                    {
+                        "config": config_name,
+                        "batch_size": batch_size,
+                        "max_seq_len": max_seq_len,
+                        "dim_embed": dim_embed,
+                        "actual_tokens": data_dict["actual_tokens"],
+                        "avg_time": avg_time,
+                        "throughput": throughput,
+                        "memory_mb": mem_usage,
+                        "result_shape": result.shape,
+                    }
+                )
+
+                print(
+                    f"{config_name:<20} {batch_size:<6} {max_seq_len:<7} "
+                    f"{data_dict['actual_tokens']:<8} {avg_time:<10.3f} "
+                    f"{throughput:<12.1f} {mem_usage:<8.1f}"
+                )
+
+                expected_shape = [batch_size, dim_embed]
+                self.assertEqual(list(result.shape), expected_shape, f"Output shape mismatch for {config_name}")
+
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to test configuration {config_name} (batch={batch_size}, seq_len={max_seq_len}): {str(e)}"
+                )
+
+        print("-" * 90)
+        return results
+
+
+def main():
+    """Run all performance tests"""
+    print("Starting CU_SEQLENS_Q Performance Benchmark...")
+    print(f"GPU: {paddle.device.cuda.get_device_name()}")
+    print(f"GPU Memory: {paddle.device.cuda.get_device_properties().total_memory / 1024**3:.1f} GB")
+
+    test_instance = TestCuSeqlensQPerformance()
+    test_instance.setUp()
+
+    try:
+        scaling_results = test_instance.test_performance_scaling()
+
+        print("\n" + "=" * 50)
print("Performance Summary") + print("=" * 50) + + if scaling_results: + best_throughput = max(scaling_results, key=lambda x: x["throughput"]) + print(f"Best throughput: {best_throughput['throughput']:.1f} tokens/ms") + print( + f" Config: {best_throughput['config']} " + f"(batch={best_throughput['batch_size']}, " + f"seq_len={best_throughput['max_seq_len']})" + ) + + print("=" * 50) + + except Exception as e: + raise RuntimeError(f"Performance benchmark failed: {str(e)}") + + +if __name__ == "__main__": + main()