From a46aa0619439f7e52030cfa8227144d1463df8e1 Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Thu, 25 Sep 2025 20:11:57 +0800 Subject: [PATCH 01/10] support logprob in mtp --- custom_ops/gpu_ops/cpp_extensions.cc | 54 ++- custom_ops/gpu_ops/rebuild_padding.cu | 81 ++++- .../speculate_get_output_with_topk.cc | 141 ++++++++ .../speculate_logprob_utils.cu | 285 ++++++++++++++++ .../speculate_save_output_with_topk.cc | 209 ++++++++++++ fastdeploy/engine/args_utils.py | 4 +- .../model_executor/layers/sample/sampler.py | 311 +++++++++++++++++- .../model_executor/pre_and_post_process.py | 4 + fastdeploy/output/token_processor.py | 41 ++- fastdeploy/spec_decode/mtp.py | 69 +++- fastdeploy/worker/output.py | 1 + 11 files changed, 1173 insertions(+), 27 deletions(-) create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 1ced2ce6fb..2409b9553a 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -332,7 +332,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length); + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob); void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &stop_flags, @@ -891,6 +893,46 @@ void SaveOutMmsgStatic(const paddle::Tensor& x, int64_t rank_id, bool save_each_rank); +std::vector SpeculateGetLogits( + const paddle::Tensor &logits, + const paddle::Tensor &first_token_logits, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateInsertFirstToken(const paddle::Tensor &token_ids, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &next_tokens, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateGetTargetLogits(const paddle::Tensor &target_logits, + const paddle::Tensor &logits, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &ori_cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &accept_num); + +void SpeculateSaveOutMmsgTopK(const paddle::Tensor &sampled_token_ids, + const paddle::Tensor &logprob_token_ids, + const paddle::Tensor &logprob_scores, + const paddle::Tensor &logprob_ranks, + const paddle::Tensor &token_num_per_batch, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor ¬_need_stop, + int mtype, + int64_t rank_id); + +void SpeculateGetOutMmsgTopK(const paddle::Tensor &output_tokens, + const paddle::Tensor &output_scores, + const paddle::Tensor &output_ranks, + int real_k, + int64_t rank_id, + bool wait_flag); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -1277,4 +1319,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function"); m.def("save_output", &SaveOutMmsgStatic, "save_output function"); + + m.def("speculate_get_logits", &SpeculateGetLogits, 
"speculate_get_logits function"); + + m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function"); + + m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); + + m.def("speculate_save_output_topk", &SpeculateSaveOutMmsgTopK, "speculate_save_output_topk function"); + + m.def("speculate_get_output_topk", &SpeculateGetOutMmsgTopK, "speculate_get_output_topk function"); } diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index 93c1bb38c2..381524ef5e 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data, template __global__ void RebuildAppendPaddingKernel(T *output_data, + T *first_token_out, const T *input_data, const int *cu_seqlens_q, const int *seq_len_this_time, @@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums, - const int bsz) { + const int bsz, + const bool enable_logprob) { AlignedVector src_vec; const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; for (int64_t i = global_idx * VecSize; i < output_elem_nums; @@ -77,6 +79,35 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, Load(&input_data[input_token_id * dim_embed + bias_idx], &src_vec); Store(src_vec, &output_data[i]); + + // printf( + // "[normal] out_token_id: %d, ori_token_id: %d, input_token_id: %d " + // "bias_idx: %d, bid: %d, seq_id: %d\n", + // out_token_id, + // ori_token_id, + // input_token_id, + // bias_idx, + // bi, + // seq_id); + + if (enable_logprob && seq_len_encoder[bi] > 0) { + int first_token_seq_id = seq_len_encoder[bi] - 2; + const int first_token_id = + ori_token_id - cum_offset_bi + first_token_seq_id; + // printf( + // "[first token] out_token_id: %d, ori_token_id: %d, " + // "first_token_id: %d, bias_idx: %d, bid: %d, " + // "first_token_seq_id: %d\n", + // out_token_id, + // ori_token_id, + // first_token_id, + // bias_idx, + // bi, + // first_token_seq_id); + Load(&input_data[first_token_id * dim_embed + bias_idx], + &src_vec); + Store(src_vec, &first_token_out[i]); + } } } @@ -89,7 +120,9 @@ std::vector rebuild_padding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -120,6 +153,9 @@ std::vector rebuild_padding( 0, D, tmp_out.place()); + // printf("token_num: %d, need_delete_token_num: %d\n", + // token_num, + // need_delete_token_num); } else { out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); @@ -130,11 +166,20 @@ std::vector rebuild_padding( int pack_num = elem_nums / PackSize; const int blocksize = 128; const int grid_size = (pack_num + blocksize - 1) / blocksize; + printf("elem_nums: %d\n", elem_nums); if (output_padding_offset) { + // if (first_token_out.is_initialized()) { + // printf("first_token_out is initialized, enable_logprob: %d\n", + // enable_logprob); + // } RebuildAppendPaddingKernel <<>>( reinterpret_cast(out.data()), + first_token_out.is_initialized() + ? 
reinterpret_cast(const_cast( + first_token_out.get_ptr()->data())) + : nullptr, reinterpret_cast(tmp_out.data()), cu_seqlens_q.data(), seq_len_this_time.data(), @@ -144,7 +189,8 @@ std::vector rebuild_padding( max_input_length, dim_embed, elem_nums, - bsz); + bsz, + enable_logprob); } else { RebuildPaddingKernel <<>>( @@ -169,7 +215,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { switch (tmp_out.type()) { case paddle::DataType::BFLOAT16: { return rebuild_padding( @@ -179,7 +227,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT16: { return rebuild_padding( @@ -189,7 +239,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT32: { return rebuild_padding( @@ -199,7 +251,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } default: { PD_THROW( @@ -217,14 +271,18 @@ std::vector RebuildPadding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { return {RebuildPaddingFunc(tmp_out, cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)}; + first_token_out, + max_input_length, + enable_logprob)}; } std::vector> RebuildPaddingInferShape( @@ -260,9 +318,10 @@ PD_BUILD_STATIC_OP(rebuild_padding) "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", - paddle::Optional("output_padding_offset")}) + paddle::Optional("output_padding_offset"), + paddle::Optional("first_token_out")}) .Outputs({"out"}) - .Attrs({"max_input_length: int"}) + .Attrs({"max_input_length: int", "enable_logprob: bool"}) .SetKernelFn(PD_KERNEL(RebuildPadding)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc new file mode 100644 index 0000000000..922198f953 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
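+
+// Receive side of the speculative top-k logprob channel: drains one message
+// from a SysV message queue (key derived from INFERENCE_MSG_QUEUE_ID via
+// ftok("/dev/shm", id)) and unpacks the meta header (stop flag, message type,
+// batch size, per-batch token counts) together with the per-token top-(K+1)
+// token ids, logprob scores and sampled-token ranks into the caller-provided
+// output tensors. `wait_flag` toggles blocking msgrcv versus IPC_NOWAIT; when
+// nothing is read, output_tokens[0] is set to -2 so the Python consumer can
+// skip the round. The receive size matches struct msgdata below:
+//   (3 + MAX_BSZ) * sizeof(int)
+//   + MAX_BSZ * (MAX_DRAFT_TOKEN_NUM * (K + 1) * (sizeof(int) + sizeof(float))
+//                + MAX_DRAFT_TOKEN_NUM * sizeof(int))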
+ +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 +#define SPECULATE_GET_WITH_OUTPUT_DEBUG + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + int meta[3 + MAX_BSZ]; // stop_flag, mtype, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, + const paddle::Tensor& output_scores, + const paddle::Tensor& output_ranks, + int real_k, + int64_t rank_id, + bool wait_flag) { + static struct msgdata msg_rcv; + int msg_queue_id = 1; + + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + msg_queue_id = inference_msg_queue_id_from_env; + } + static key_t key = ftok("/dev/shm", msg_queue_id); + + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "get_output_key: " << key << std::endl; + std::cout << "get_output msgid: " << msgid << std::endl; +#endif + + int64_t* output_tokens_data = + const_cast(output_tokens.data()); + float* output_scores_data = const_cast(output_scores.data()); + int64_t* output_ranks_data = + const_cast(output_ranks.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv( + msgid, + &msg_rcv, + (3 + MAX_BSZ) * sizeof(int) + + MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + + (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + + MAX_DRAFT_TOKEN_NUM * sizeof(int)), + 0, + IPC_NOWAIT); + } else { + ret = msgrcv( + msgid, + &msg_rcv, + (3 + MAX_BSZ) * sizeof(int) + + MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + + (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + + MAX_DRAFT_TOKEN_NUM * sizeof(int)), + 0, + 0); + } + if (ret == -1) { + // read none + output_tokens_data[0] = -2; // stop_flag + output_tokens_data[1] = msg_rcv.meta[1]; // mtype, Target: 3, Draft: 4 + output_tokens_data[2] = 0; // bsz + return; + } + + int bsz = msg_rcv.meta[1]; + output_tokens_data[0] = msg_rcv.meta[0]; + output_tokens_data[1] = msg_rcv.meta[1]; + output_tokens_data[2] = msg_rcv.meta[2]; + + int output_tokens_offset = 3 + MAX_BSZ; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_rcv.meta[3 + i]; + output_tokens_data[3 + i] = cur_token_num; // batch_token_nums + + auto* cur_output_token = output_tokens_data + output_tokens_offset + + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + auto* cur_output_score = + output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; + for (int j = 0; j < cur_token_num; j++) { + for (int k = 0; k < real_k + 1; k++) { + cur_output_token[j * (K + 1) + k] = + cur_batch_msg_rcv->tokens[k]; + cur_output_score[j * (K + 1) + k] = + cur_batch_msg_rcv->scores[k]; + } + output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = + cur_batch_msg_rcv->ranks[j]; + } + } +} + +PD_BUILD_STATIC_OP(speculate_get_output_topk) + .Inputs({"output_tokens", "output_scores", "output_ranks"}) + .Attrs({"real_k: int", "rank_id: int64_t", 
"wait_flag: bool"}) + .Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"}) + .SetInplaceMap({{"output_tokens", "output_tokens_out"}, + {"output_scores", "output_scores_out"}, + {"output_ranks", "output_ranks_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu new file mode 100644 index 0000000000..f40e099a1e --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -0,0 +1,285 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +template +__global__ void get_token_num_per_batch_kernel(int* batch_token_num, + int* total_token_num, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int real_bsz) { + int bid = threadIdx.x; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int token_num_now = 0; + if (bid < real_bsz) { + token_num_now = seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid]; + batch_token_num[bid] = token_num_now; + } + + __syncthreads(); + int token_num_sum = BlockReduce(temp_storage).Sum(token_num_now); + if (bid == 0) { + total_token_num[0] = token_num_sum; + } +} + +template +__global__ void speculate_get_logits_kernel(float* draft_logits, + const float* logits, + const float* first_token_logits, + const int* cu_seqlens_q, + const int* cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* draft_logits_now = + draft_logits + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + cu_seqlens_q[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&first_token_logits[bid * vocab_size + i], + &src_vec); + Store(src_vec, &draft_logits_now[i]); + + Load(&logits_now[i], &src_vec); + Store(src_vec, + &draft_logits_now[vocab_size + i]); + } else { + for (int j = 0; j < seq_lens_this_time[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &draft_logits_now[j * vocab_size + i]); + } + } + } + } +} + +std::vector SpeculateGetLogits( + const paddle::Tensor& logits, + const paddle::Tensor& first_token_logits, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + auto total_token_num = paddle::full( + {1}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + auto batch_token_num = paddle::full( + {real_bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + + 
constexpr int THREADBLOCK_SIZE = 512; + get_token_num_per_batch_kernel + <<<1, THREADBLOCK_SIZE, 0, cu_stream>>>(batch_token_num.data(), + total_token_num.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + real_bsz); + + auto total_token_num_cpu = + total_token_num.copy_to(paddle::CPUPlace(), true); + + auto draft_logits = + paddle::empty({total_token_num_cpu.data()[0], vocab_size}, + paddle::DataType::FLOAT32, + seq_lens_this_time.place()); + auto cu_batch_token_offset = paddle::full( + {real_bsz + 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + + void* temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(temp_storage, + temp_storage_bytes, + batch_token_num.data(), + &cu_batch_token_offset.data()[1], + real_bsz, + cu_stream); + cudaMalloc(&temp_storage, temp_storage_bytes); + cub::DeviceScan::InclusiveSum(temp_storage, + temp_storage_bytes, + batch_token_num.data(), + &cu_batch_token_offset.data()[1], + real_bsz, + cu_stream); + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_logits_kernel + <<>>( + const_cast(draft_logits.data()), + logits.data(), + first_token_logits.data(), + cu_seqlens_q.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + vocab_size, + real_bsz); + + return {draft_logits, batch_token_num, cu_batch_token_offset}; +} + +__global__ void speculate_insert_first_token_kernel( + int64_t* token_ids, + const int64_t* accept_tokens, + const int64_t* next_tokens, + const int* cu_seqlens_q, + const int* cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int max_draft_tokens, + const int real_bsz) { + const int bid = threadIdx.x; + + auto* token_ids_now = token_ids + cu_batch_token_offset[bid]; + auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens; + auto* next_tokens_now = next_tokens + cu_seqlens_q[bid]; + if (seq_lens_encoder[bid] != 0) { + token_ids_now[0] = accept_tokens_now[0]; + token_ids_now[1] = next_tokens_now[0]; + } else { + for (int i = 0; i < seq_lens_this_time[bid]; i++) { + token_ids_now[i] = next_tokens_now[i]; + } + } +} + +void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int max_draft_tokens = accept_tokens.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + speculate_insert_first_token_kernel<<<1, real_bsz, 0, cu_stream>>>( + const_cast(token_ids.data()), + accept_tokens.data(), + next_tokens.data(), + cu_seqlens_q.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_draft_tokens, + real_bsz); +} + +template +__global__ void speculate_get_target_logits_kernel( + float* target_logtis, + const float* logits, + const int* cu_batch_token_offset, + const int* ori_cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* accept_num, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* target_logtis_now = + target_logtis + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + 
ori_cu_batch_token_offset[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&logits_now[i], &src_vec); + Store(src_vec, &target_logtis_now[i]); + } else { + for (int j = 0; j < accept_num[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &target_logtis_now[j * vocab_size + i]); + } + } + } + } +} + +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_target_logits_kernel + <<>>( + const_cast(target_logits.data()), + logits.data(), + cu_batch_token_offset.data(), + ori_cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + accept_num.data(), + vocab_size, + real_bsz); +} + +PD_BUILD_STATIC_OP(speculate_get_logits) + .Inputs({"logits", + "first_token_logits", + "cu_seqlens_q", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"draft_logits", "batch_token_num", "cu_batch_token_offset"}) + .SetKernelFn(PD_KERNEL(SpeculateGetLogits)); + +PD_BUILD_STATIC_OP(speculate_insert_first_token) + .Inputs({"token_ids", + "accept_tokens", + "next_tokens", + "cu_seqlens_q", + "cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"token_ids_out"}) + .SetInplaceMap({{"token_ids", "token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken)); + +PD_BUILD_STATIC_OP(speculate_get_target_logits) + .Inputs({"target_logits", + "logits", + "cu_batch_token_offset", + "ori_cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder", + "accept_num"}) + .Outputs({"target_logits_out"}) + .SetInplaceMap({{"target_logits", "target_logits_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc new file mode 100644 index 0000000000..10f547cea0 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
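+
+// Send side of the speculative top-k logprob channel: packs the sampled token
+// ids, top-K logprob token ids/scores and sampled-token ranks of every
+// accepted token in every batch into struct msgdata and pushes it onto the
+// same SysV message queue that speculate_get_output_with_topk.cc reads. Only
+// rank 0 sends (rank_id > 0 returns immediately). meta[0] carries the
+// INFERENCE_MSG_ID, negated when generation should stop; meta[1] is the
+// message type (3 for the target model, 4 for the draft model); meta[2] is the
+// batch size; meta[3..] holds the per-batch token counts.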
+ +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 +#define SPECULATE_SAVE_WITH_OUTPUT_DEBUG + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + int meta[3 + MAX_BSZ]; // stop_flag, mtype, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& logprob_token_ids, + const paddle::Tensor& logprob_scores, + const paddle::Tensor& logprob_ranks, + const paddle::Tensor& token_num_per_batch, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& not_need_stop, + int mtype, // Target: 3, Draft: 4 + int64_t rank_id) { + if (rank_id > 0) { + return; + } + auto sampled_token_ids_cpu = + sampled_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_token_ids_cpu = + logprob_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false); + auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false); + auto token_num_per_batch_cpu = + token_num_per_batch.copy_to(paddle::CPUPlace(), false); + auto cu_batch_token_offset_cpu = + cu_batch_token_offset.copy_to(paddle::CPUPlace(), false); + int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data(); + int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data(); + float* logprob_scores_data = logprob_scores_cpu.data(); + int64_t* logprob_ranks_data = logprob_ranks_cpu.data(); + int* token_num_per_batch_data = token_num_per_batch_cpu.data(); + int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data(); + + static struct msgdata msg_sed; + int msg_queue_id = 1; + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 is perserve for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout + << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." 
+ << std::endl; +#endif + } + static key_t key = ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output_key: " << key << std::endl; + std::cout << "save msgid: " << msgid << std::endl; +#endif + + msg_sed.meta[0] = not_need_stop.data()[0] + ? inference_msg_id_from_env + : -inference_msg_id_from_env; + msg_sed.meta[1] = mtype; + int bsz = token_num_per_batch.shape()[0]; + msg_sed.meta[2] = bsz; + int max_num_logprobs = logprob_token_ids.shape()[1]; + for (int i = 0; i < bsz; i++) { + int cur_token_num = token_num_per_batch_data[i]; + msg_sed.meta[3 + i] = cur_token_num; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + int token_offset = cu_batch_token_offset_data[i]; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + std::cout << "token_offset: " << token_offset << std::endl; + for (int k = 0; k < K + 1; k++) { + if (k == 0) { + cur_tokens[k] = + (int)sampled_token_ids_data[token_offset + j]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else if (k < max_num_logprobs) { + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * (K + 1) + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else { + cur_tokens[k] = -1; + cur_scores[k] = 0.0; + } + } + cur_batch_msg_sed->ranks[j] = + (int)logprob_ranks_data[token_offset + j]; + } + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << msg_sed.meta[0] + << ", mtype: " << msg_sed.meta[1] << ", bsz: " << msg_sed.meta[2] + << std::endl; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_sed.meta[3 + i]; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + std::cout << "batch " << i << " token_num: " << cur_token_num + << std::endl; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + std::cout << "tokens: "; + for (int k = 0; k < K + 1; k++) { + std::cout << cur_tokens[k] << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < K + 1; k++) { + std::cout << cur_scores[k] << " "; + } + std::cout << std::endl; + std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl; + } + } + std::cout << std::endl; +#endif + if ((msgsnd(msgid, + &msg_sed, + (3 + MAX_BSZ) * sizeof(int) + + MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + + (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + + MAX_DRAFT_TOKEN_NUM * sizeof(int)), + 0)) == -1) { + printf("full msg buffer\n"); + } +} + +PD_BUILD_STATIC_OP(speculate_save_output_topk) + .Inputs({ + "sampled_token_ids", + "logprob_token_ids", + "logprob_scores", + "logprob_ranks", + "token_num_per_batch", + "cu_batch_token_offset", + "not_need_stop", + }) + .Attrs({"mtype: int", "rank_id: int64_t"}) + .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK)); diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2f3710bfa4..c08128317b 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -403,8 +403,8 @@ def __post_init__(self): if self.dynamic_load_weight: self.enable_prefix_caching = False if self.enable_logprob: - if self.speculative_config is not None: - raise NotImplementedError("Logprob does not support speculation_config.") + # if 
self.speculative_config is not None: + # raise NotImplementedError("Logprob does not support speculation_config.") if not current_platform.is_cuda(): raise NotImplementedError("Only CUDA platform supports logprob.") if self.splitwise_role != "mixed": diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 5aecfa1f9e..bc1092caf5 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -36,6 +36,10 @@ min_p_sampling, top_k_top_p_sampling, ) +from fastdeploy.model_executor.ops.gpu import ( + speculate_get_target_logits, + speculate_insert_first_token, +) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -286,8 +290,11 @@ def gather_logprobs( # Get with the logprob of the prompt or sampled token. token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + print(f"[Sampler] logprobs: {logprobs}") + print(f"[Sampler] token_logprobs: {token_logprobs}") # Compute the ranks of the actual token. token_ranks = (logprobs >= token_logprobs).sum(-1) + print(f"[Sampler] token_ranks: {token_ranks}") if num_logprobs >= 1: # Find the topK values. @@ -356,6 +363,7 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, ) + print(f"[Sampler] sampler_output: {sampler_output}") return sampler_output @@ -375,6 +383,7 @@ def __init__(self, fd_config: FDConfig): self.speculative_verify_window = fd_config.speculative_config.verify_window self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -389,6 +398,104 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + last_logits = logits + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + print(f"[SpeculativeSampler][compute] seq_lens_this_time: {share_inputs['seq_lens_this_time']}") + print(f"[SpeculativeSampler][compute] seq_lens_encoder: {share_inputs['seq_lens_encoder']}") + batch_token_num = share_inputs["batch_token_num"] + + print(f"[SpeculativeSampler][compute] batch_token_num: {batch_token_num}") + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) + temp_temperature = paddle.where( + real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + last_logits = last_logits / temp_temperature + + last_logprobs = F.log_softmax(last_logits, axis=-1) + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and share_inputs is not None: + real_token_top_p = ( + 
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) + ) + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(batch_token_num) + .astype("bool") + .unsqueeze(1) + ) + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + print(f"[SpeculativeSampler] logprobs: {logprobs}") + print(f"[SpeculativeSampler] token_logprobs: {token_logprobs}") + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + print(f"[SpeculativeSampler] token_ranks: {token_ranks}") + + if num_logprobs >= 1: + # Find the topK values. 
+ topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -427,6 +534,9 @@ def forward_cuda( max_model_len, ) + print(f"[SpeculativeSampler] verify_tokens: {verify_tokens}") + print(f"[SpeculativeSampler] actual_candidate_len: {actual_candidate_len}") + speculate_verify( share_inputs["accept_tokens"], share_inputs["accept_num"], @@ -452,8 +562,64 @@ def forward_cuda( True, # enable_topp self.speculative_benchmark_mode, ) + print(f"[SpeculativeSampler] accept_num: {share_inputs['accept_num']}") + print(f"[SpeculativeSampler] accept_tokens: {share_inputs['accept_tokens']}") + + print(f"[SpeculativeSampler] logits: {logits}") + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + batch_token_num = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz] != 0, + paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), + share_inputs["accept_num"][:real_bsz].unsqueeze(1), + ).squeeze(1) + share_inputs["batch_token_num"] = batch_token_num + print(f"[SpeculativeSampler] batch_token_num: {share_inputs['batch_token_num']}") + ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( + "int32" + ) + cu_batch_token_offset = paddle.concat( + [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])] + ).astype("int32") + print(f"[SpeculativeSampler] ori_cu_batch_token_offset: {ori_cu_batch_token_offset}") + print(f"[SpeculativeSampler] cu_batch_token_offset: {cu_batch_token_offset}") + target_logtis = paddle.empty([share_inputs["accept_num"].sum(), logits.shape[1]], dtype=logits.dtype) + speculate_get_target_logits( + target_logtis, + logits, + cu_batch_token_offset, + ori_cu_batch_token_offset, + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["accept_num"], + ) + print(f"[SpeculativeSampler] target_logtis: {target_logtis}") + raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata) + print(f"[SpeculativeSampler] raw_logprobs: {raw_logprobs}") + + sampler_output = None + if num_logprobs is not None: - return None + token_ids = share_inputs["accept_tokens"] + token_ids = paddle.concat( + [ + share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] + for i in range(share_inputs["accept_num"].shape[0]) + ] + ) + print(f"[SpeculativeSampler] token_ids: {token_ids}") + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=batch_token_num, + ) + + print(f"[SpeculativeSampler] sampler_output: {sampler_output}") + + return sampler_output class MTPSampler(nn.Layer): @@ -466,6 +632,7 @@ def __init__(self, fd_config: FDConfig): self.forward = self.forward_cuda else: raise NotImplementedError + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -480,6 +647,115 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + 
sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + last_logits = logits + # print(f"[MTPSampler][compute] real_bsz: {real_bsz}") + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"]) + # print(f"[MTPSampler][compute] real_bsz_temp_scaled: {real_bsz_temp_scaled}") + # print(f"[MTPSampler][compute] temperature: {temperature}") + temp_temperature = paddle.where( + real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + # print(f"[MTPSampler][compute] temp_temperature: {temp_temperature}") + last_logits = last_logits / temp_temperature + # print(f"[MTPSampler][compute] last_logits: {last_logits}") + + last_logprobs = F.log_softmax(last_logits, axis=-1) + # print(f"[MTPSampler][compute] last_logits: {last_logits}") + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and share_inputs is not None: + real_token_top_p = ( + sampling_metadata.top_p[:real_bsz] + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .unsqueeze(1) + ) + # print(f"[MTPSampler][compute] real_token_top_p: {real_token_top_p}") + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .astype("bool") + .unsqueeze(1) + ) + # print(f"[MTPSampler][compute] top_p_normalized_logprobs: {top_p_normalized_logprobs}") + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + # print(f"[MTPSampler][compute] top_p_token_mask: {top_p_token_mask}") + + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + # print(f"[MTPSampler][compute] probs: {probs}") + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + # print(f"[MTPSampler][compute] probs: {probs}") + top_p_logprob = paddle.log(probs) + # print(f"[MTPSampler][compute] top_p_logprob: {top_p_logprob}") + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. 
+ token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. + topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -488,6 +764,11 @@ def forward_cuda( share_inputs: List[paddle.Tensor], ) -> paddle.Tensor: """ """ + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None and share_inputs["substep"] == 0: + raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"], sampling_metadata) + print(f"[MTPSampler] raw_logprobs: {raw_logprobs}") + logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, @@ -509,4 +790,30 @@ def forward_cuda( _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list ) - return next_tokens + + sampler_output = None + if num_logprobs is not None and share_inputs["substep"] == 0: + token_ids = paddle.empty(share_inputs["batch_token_num"].sum(), dtype="int64") + speculate_insert_first_token( + token_ids, + share_inputs["accept_tokens"], + next_tokens, + share_inputs["cu_seqlens_q"], + share_inputs["cu_batch_token_offset"], + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + ) + print(f"[MTPSampler] token_ids: {token_ids}") + print(f"[MTPSampler] total_token_num: {share_inputs['batch_token_num'].sum()}") + + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=share_inputs["batch_token_num"], + ) + print(f"[MTPSampler] sampler_output: {sampler_output}") + print(f"[MTPSampler] next_tokens: {next_tokens}") + + return next_tokens, sampler_output diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 796476de7f..04ef4f009d 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -529,6 +529,8 @@ def rebuild_padding( seq_lens_encoder: paddle.Tensor, output_padding_offset: Optional[paddle.Tensor] = None, max_input_length: Optional[int] = None, + first_token_out: Optional[paddle.Tensor] = None, + enable_logprob: Optional[bool] = False, ): """ Args: @@ -544,7 +546,9 @@ def rebuild_padding( seq_lens_decoder, seq_lens_encoder, output_padding_offset, + first_token_out, max_input_length, + enable_logprob, ) elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import rebuild_padding diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 73585ef776..a4262a2999 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -60,11 +60,20 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: - self.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) + if self.use_logprobs: + 
self.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64" + ) + self.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32" + ) + self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64") + else: + self.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") @@ -149,6 +158,7 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, + speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -156,9 +166,24 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue + if self.use_logprobs: + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + is_blocking, + ) + print(f"[TokenProcessor] output_tokens: {self.output_tokens}") + print(f"[TokenProcessor] output_scores: {self.output_scores}") + print(f"[TokenProcessor] output_ranks: {self.output_ranks}") + if self.output_tokens[0, 0] == -2: + continue + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6ec6ee1906..bf2571b2c0 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -41,6 +41,8 @@ mtp_save_first_token, mtp_step_paddle, share_external_data, + speculate_get_logits, + speculate_save_output_topk, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -62,6 +64,7 @@ def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs): self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps + self.enable_logprob = self.model_config.enable_logprob # [mixed, prefill, decoder] self.role = "mixed" @@ -354,6 +357,13 @@ def _init_model_inputs(self): self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"] + self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"] + self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"] + self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"] + self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"] + max_num_seqs = self.model_inputs["seq_lens_encoder"].shape[0] + self.model_inputs["first_token_hidden_states"] = paddle.full([max_num_seqs, self.model_config.hidden_size], -1) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): @@ -616,6 +626,7 @@ def _propose(self, target_hidden_states): """ for substep in range(self.num_model_steps): if 
self.model_inputs["not_need_stop"]: + print(f"[MTPProposer] ******************** substep: {substep} ********************") self.model_inputs["substep"] = substep # Remove padding ( @@ -657,6 +668,10 @@ def _propose(self, target_hidden_states): min_dec_lens=self.model_inputs["min_dec_len"], bad_words_token_ids=self.model_inputs["bad_tokens"], eos_token_ids=self.model_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, ) if self.num_model_steps > 1: @@ -667,7 +682,19 @@ def _propose(self, target_hidden_states): previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) - + print(f"[MTPProposer] model_output: {model_output}") + + if self.enable_logprob and substep == 0: + first_token_hidden_states = paddle.empty( + [self.max_num_seqs, self.model_config.hidden_size], dtype=model_output.dtype + ) + + print(f"[MTPProposer] cu_seqlens_q: {self.model_inputs['cu_seqlens_q']}") + print(f"[MTPProposer] seq_lens_this_time: {self.model_inputs['seq_lens_this_time']}") + print(f"[MTPProposer] seq_lens_encoder: {self.model_inputs['seq_lens_encoder']}") + print(f"[MTPProposer] seq_lens_decoder: {self.model_inputs['seq_lens_decoder']}") + print(f"[MTPProposer] output_cum_offsets: {self.model_inputs['output_cum_offsets']}") + print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}") hidden_states = rebuild_padding( model_output, self.model_inputs["cu_seqlens_q"], @@ -676,18 +703,54 @@ def _propose(self, target_hidden_states): self.model_inputs["seq_lens_encoder"], self.model_inputs["output_padding_offset"], self.parallel_config.max_model_len, + first_token_hidden_states if substep == 0 else None, + self.enable_logprob if substep == 0 else False, ) + print(f"[MTPProposer] hidden_states: {hidden_states}") + print(f"[MTPProposer] first_token_hidden_states: {first_token_hidden_states}") # 4. 
Compute logits, Sample logits = self.model.compute_logits(hidden_states) - - sampled_token_ids = self.sampler( + if self.enable_logprob and substep == 0: + first_token_logits = self.model.compute_logits(first_token_hidden_states) + print(f"[MTPProposer] logits: {logits}") + print(f"[MTPProposer] first_token_logits: {first_token_logits}") + print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}") + + draft_logits, batch_token_num, cu_batch_token_offset = speculate_get_logits( + logits, + first_token_logits, + self.model_inputs["cu_seqlens_q"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + ) + self.model_inputs["draft_logits"] = draft_logits + self.model_inputs["batch_token_num"] = batch_token_num + self.model_inputs["cu_batch_token_offset"] = cu_batch_token_offset + print(f"[MTPProposer] draft_logits: {draft_logits}") + print(f"[MTPProposer] batch_token_num: {batch_token_num}") + print(f"[MTPProposer] cu_batch_token_offset: {cu_batch_token_offset}") + + sampled_token_ids, sampler_output = self.sampler( logits, self.sampling_metadata, self.max_model_len, self.model_inputs, ) + if substep == 0: + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + batch_token_num, + cu_batch_token_offset, + self.model_inputs["not_need_stop"], + 4, # mtype + self.local_rank, + ) + if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast(sampled_token_ids, 0) diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 6d820a873a..9bc35d4ee0 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -106,6 +106,7 @@ class SamplerOutput: # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. 
sampled_token_ids: paddle.Tensor logprobs_tensors: Optional[LogprobsTensors] + token_num_per_batch: Optional[paddle.Tensor] @dataclass From 03ab5530d89a98697acffcb191e13ed122a5ba84 Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Thu, 25 Sep 2025 20:58:08 +0800 Subject: [PATCH 02/10] remove debug code --- custom_ops/gpu_ops/rebuild_padding.cu | 27 ------------ .../speculate_get_output_with_topk.cc | 2 +- .../speculate_save_output_with_topk.cc | 3 +- fastdeploy/engine/args_utils.py | 2 - .../model_executor/layers/sample/sampler.py | 42 +------------------ fastdeploy/output/token_processor.py | 22 ++-------- fastdeploy/spec_decode/mtp.py | 16 ------- 7 files changed, 6 insertions(+), 108 deletions(-) diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index 381524ef5e..a80db03035 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -80,30 +80,10 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, &src_vec); Store(src_vec, &output_data[i]); - // printf( - // "[normal] out_token_id: %d, ori_token_id: %d, input_token_id: %d " - // "bias_idx: %d, bid: %d, seq_id: %d\n", - // out_token_id, - // ori_token_id, - // input_token_id, - // bias_idx, - // bi, - // seq_id); - if (enable_logprob && seq_len_encoder[bi] > 0) { int first_token_seq_id = seq_len_encoder[bi] - 2; const int first_token_id = ori_token_id - cum_offset_bi + first_token_seq_id; - // printf( - // "[first token] out_token_id: %d, ori_token_id: %d, " - // "first_token_id: %d, bias_idx: %d, bid: %d, " - // "first_token_seq_id: %d\n", - // out_token_id, - // ori_token_id, - // first_token_id, - // bias_idx, - // bi, - // first_token_seq_id); Load(&input_data[first_token_id * dim_embed + bias_idx], &src_vec); Store(src_vec, &first_token_out[i]); @@ -153,9 +133,6 @@ std::vector rebuild_padding( 0, D, tmp_out.place()); - // printf("token_num: %d, need_delete_token_num: %d\n", - // token_num, - // need_delete_token_num); } else { out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); @@ -169,10 +146,6 @@ std::vector rebuild_padding( printf("elem_nums: %d\n", elem_nums); if (output_padding_offset) { - // if (first_token_out.is_initialized()) { - // printf("first_token_out is initialized, enable_logprob: %d\n", - // enable_logprob); - // } RebuildAppendPaddingKernel <<>>( reinterpret_cast(out.data()), diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 922198f953..1dc7e416cc 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -26,7 +26,7 @@ #define MAX_BSZ 512 #define K 20 #define MAX_DRAFT_TOKEN_NUM 6 -#define SPECULATE_GET_WITH_OUTPUT_DEBUG +// #define SPECULATE_GET_WITH_OUTPUT_DEBUG struct batch_msgdata { int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 10f547cea0..b345893c98 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -26,7 +26,7 @@ #define MAX_BSZ 512 #define K 20 #define MAX_DRAFT_TOKEN_NUM 6 -#define SPECULATE_SAVE_WITH_OUTPUT_DEBUG +// #define SPECULATE_SAVE_WITH_OUTPUT_DEBUG struct batch_msgdata { int 
tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; @@ -134,7 +134,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, for (int j = 0; j < cur_token_num; j++) { auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; - std::cout << "token_offset: " << token_offset << std::endl; for (int k = 0; k < K + 1; k++) { if (k == 0) { cur_tokens[k] = diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index c08128317b..e75396530b 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -403,8 +403,6 @@ def __post_init__(self): if self.dynamic_load_weight: self.enable_prefix_caching = False if self.enable_logprob: - # if self.speculative_config is not None: - # raise NotImplementedError("Logprob does not support speculation_config.") if not current_platform.is_cuda(): raise NotImplementedError("Only CUDA platform supports logprob.") if self.splitwise_role != "mixed": diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index bc1092caf5..8c5e7190d1 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -290,11 +290,8 @@ def gather_logprobs( # Get with the logprob of the prompt or sampled token. token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) - print(f"[Sampler] logprobs: {logprobs}") - print(f"[Sampler] token_logprobs: {token_logprobs}") # Compute the ranks of the actual token. token_ranks = (logprobs >= token_logprobs).sum(-1) - print(f"[Sampler] token_ranks: {token_ranks}") if num_logprobs >= 1: # Find the topK values. @@ -363,7 +360,6 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, ) - print(f"[Sampler] sampler_output: {sampler_output}") return sampler_output @@ -407,11 +403,8 @@ def compute_logprobs( share_inputs = sampling_metadata.share_inputs last_logits = logits real_bsz = share_inputs["seq_lens_this_time"].shape[0] - print(f"[SpeculativeSampler][compute] seq_lens_this_time: {share_inputs['seq_lens_this_time']}") - print(f"[SpeculativeSampler][compute] seq_lens_encoder: {share_inputs['seq_lens_encoder']}") batch_token_num = share_inputs["batch_token_num"] - print(f"[SpeculativeSampler][compute] batch_token_num: {batch_token_num}") temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs if temp_scaled_logprobs is not None: @@ -479,11 +472,8 @@ def gather_logprobs( # Get with the logprob of the prompt or sampled token. token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) - print(f"[SpeculativeSampler] logprobs: {logprobs}") - print(f"[SpeculativeSampler] token_logprobs: {token_logprobs}") # Compute the ranks of the actual token. token_ranks = (logprobs >= token_logprobs).sum(-1) - print(f"[SpeculativeSampler] token_ranks: {token_ranks}") if num_logprobs >= 1: # Find the topK values. 
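The rank bookkeeping kept by these hunks (only the debug prints are removed) is the part most worth a second look, so here is a small self-contained illustration; the toy values are invented for the example, while the two formula lines are the ones used in gather_logprobs.

    import paddle

    # toy log-probabilities for 2 tokens over a 4-word vocabulary (values invented)
    logprobs = paddle.to_tensor([[-0.1, -2.0, -3.0, -0.5],
                                 [-1.5, -0.2, -2.5, -0.9]])
    token_ids = paddle.to_tensor([[3], [1]])  # sampled/accepted token per row

    # same two lines as gather_logprobs: take the sampled token's logprob, then
    # count how many vocabulary entries score at least as high -> its 1-based rank
    token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
    token_ranks = (logprobs >= token_logprobs).sum(-1)
    print(token_ranks.tolist())  # [2, 1]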
@@ -534,9 +524,6 @@ def forward_cuda( max_model_len, ) - print(f"[SpeculativeSampler] verify_tokens: {verify_tokens}") - print(f"[SpeculativeSampler] actual_candidate_len: {actual_candidate_len}") - speculate_verify( share_inputs["accept_tokens"], share_inputs["accept_num"], @@ -562,10 +549,7 @@ def forward_cuda( True, # enable_topp self.speculative_benchmark_mode, ) - print(f"[SpeculativeSampler] accept_num: {share_inputs['accept_num']}") - print(f"[SpeculativeSampler] accept_tokens: {share_inputs['accept_tokens']}") - print(f"[SpeculativeSampler] logits: {logits}") num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: real_bsz = share_inputs["seq_lens_this_time"].shape[0] @@ -575,15 +559,13 @@ def forward_cuda( share_inputs["accept_num"][:real_bsz].unsqueeze(1), ).squeeze(1) share_inputs["batch_token_num"] = batch_token_num - print(f"[SpeculativeSampler] batch_token_num: {share_inputs['batch_token_num']}") ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( "int32" ) cu_batch_token_offset = paddle.concat( [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])] ).astype("int32") - print(f"[SpeculativeSampler] ori_cu_batch_token_offset: {ori_cu_batch_token_offset}") - print(f"[SpeculativeSampler] cu_batch_token_offset: {cu_batch_token_offset}") + share_inputs["cu_batch_token_offset"] = cu_batch_token_offset target_logtis = paddle.empty([share_inputs["accept_num"].sum(), logits.shape[1]], dtype=logits.dtype) speculate_get_target_logits( target_logtis, @@ -594,9 +576,7 @@ def forward_cuda( share_inputs["seq_lens_encoder"], share_inputs["accept_num"], ) - print(f"[SpeculativeSampler] target_logtis: {target_logtis}") raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata) - print(f"[SpeculativeSampler] raw_logprobs: {raw_logprobs}") sampler_output = None if num_logprobs is not None: @@ -608,7 +588,6 @@ def forward_cuda( for i in range(share_inputs["accept_num"].shape[0]) ] ) - print(f"[SpeculativeSampler] token_ids: {token_ids}") logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) sampler_output = SamplerOutput( @@ -617,8 +596,6 @@ def forward_cuda( token_num_per_batch=batch_token_num, ) - print(f"[SpeculativeSampler] sampler_output: {sampler_output}") - return sampler_output @@ -656,7 +633,6 @@ def compute_logprobs( share_inputs = sampling_metadata.share_inputs real_bsz = share_inputs["seq_lens_this_time"].shape[0] last_logits = logits - # print(f"[MTPSampler][compute] real_bsz: {real_bsz}") temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs if temp_scaled_logprobs is not None: @@ -669,17 +645,12 @@ def compute_logprobs( .astype("bool") ) temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"]) - # print(f"[MTPSampler][compute] real_bsz_temp_scaled: {real_bsz_temp_scaled}") - # print(f"[MTPSampler][compute] temperature: {temperature}") temp_temperature = paddle.where( real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) ).unsqueeze(1) - # print(f"[MTPSampler][compute] temp_temperature: {temp_temperature}") last_logits = last_logits / temp_temperature - # print(f"[MTPSampler][compute] last_logits: {last_logits}") last_logprobs = F.log_softmax(last_logits, axis=-1) - # print(f"[MTPSampler][compute] last_logits: {last_logits}") top_p_logprob = None top_p_token_mask = None @@ -690,7 +661,6 @@ def compute_logprobs( 
.repeat_interleave(share_inputs["batch_token_num"]) .unsqueeze(1) ) - # print(f"[MTPSampler][compute] real_token_top_p: {real_token_top_p}") top_p_normalized_logprobs = ( top_p_normalized_logprobs[:real_bsz] .astype("int32") @@ -699,17 +669,12 @@ def compute_logprobs( .astype("bool") .unsqueeze(1) ) - # print(f"[MTPSampler][compute] top_p_normalized_logprobs: {top_p_normalized_logprobs}") top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) - # print(f"[MTPSampler][compute] top_p_token_mask: {top_p_token_mask}") if top_p_token_mask.any(): probs = F.softmax(last_logits, axis=-1) - # print(f"[MTPSampler][compute] probs: {probs}") probs = top_p_normalize_probs_paddle(probs, real_token_top_p) - # print(f"[MTPSampler][compute] probs: {probs}") top_p_logprob = paddle.log(probs) - # print(f"[MTPSampler][compute] top_p_logprob: {top_p_logprob}") if top_p_logprob is not None: last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) return last_logprobs @@ -767,7 +732,6 @@ def forward_cuda( num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None and share_inputs["substep"] == 0: raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"], sampling_metadata) - print(f"[MTPSampler] raw_logprobs: {raw_logprobs}") logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, @@ -803,8 +767,6 @@ def forward_cuda( share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], ) - print(f"[MTPSampler] token_ids: {token_ids}") - print(f"[MTPSampler] total_token_num: {share_inputs['batch_token_num'].sum()}") logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) @@ -813,7 +775,5 @@ def forward_cuda( logprobs_tensors=logprobs_tensors, token_num_per_batch=share_inputs["batch_token_num"], ) - print(f"[MTPSampler] sampler_output: {sampler_output}") - print(f"[MTPSampler] next_tokens: {next_tokens}") return next_tokens, sampler_output diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index a4262a2999..e48260fc66 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -158,7 +158,6 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, - speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -166,24 +165,9 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - if self.use_logprobs: - speculate_get_output_topk( - self.output_tokens, - self.output_scores, - self.output_ranks, - K, - rank_id, - is_blocking, - ) - print(f"[TokenProcessor] output_tokens: {self.output_tokens}") - print(f"[TokenProcessor] output_scores: {self.output_scores}") - print(f"[TokenProcessor] output_ranks: {self.output_ranks}") - if self.output_tokens[0, 0] == -2: - continue - else: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index bf2571b2c0..a02fbc05d3 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -626,7 +626,6 @@ def _propose(self, target_hidden_states): """ for substep in range(self.num_model_steps): if self.model_inputs["not_need_stop"]: - print(f"[MTPProposer] ******************** 
substep: {substep} ********************") self.model_inputs["substep"] = substep # Remove padding ( @@ -682,19 +681,12 @@ def _propose(self, target_hidden_states): previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) - print(f"[MTPProposer] model_output: {model_output}") if self.enable_logprob and substep == 0: first_token_hidden_states = paddle.empty( [self.max_num_seqs, self.model_config.hidden_size], dtype=model_output.dtype ) - print(f"[MTPProposer] cu_seqlens_q: {self.model_inputs['cu_seqlens_q']}") - print(f"[MTPProposer] seq_lens_this_time: {self.model_inputs['seq_lens_this_time']}") - print(f"[MTPProposer] seq_lens_encoder: {self.model_inputs['seq_lens_encoder']}") - print(f"[MTPProposer] seq_lens_decoder: {self.model_inputs['seq_lens_decoder']}") - print(f"[MTPProposer] output_cum_offsets: {self.model_inputs['output_cum_offsets']}") - print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}") hidden_states = rebuild_padding( model_output, self.model_inputs["cu_seqlens_q"], @@ -706,16 +698,11 @@ def _propose(self, target_hidden_states): first_token_hidden_states if substep == 0 else None, self.enable_logprob if substep == 0 else False, ) - print(f"[MTPProposer] hidden_states: {hidden_states}") - print(f"[MTPProposer] first_token_hidden_states: {first_token_hidden_states}") # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) if self.enable_logprob and substep == 0: first_token_logits = self.model.compute_logits(first_token_hidden_states) - print(f"[MTPProposer] logits: {logits}") - print(f"[MTPProposer] first_token_logits: {first_token_logits}") - print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}") draft_logits, batch_token_num, cu_batch_token_offset = speculate_get_logits( logits, @@ -727,9 +714,6 @@ def _propose(self, target_hidden_states): self.model_inputs["draft_logits"] = draft_logits self.model_inputs["batch_token_num"] = batch_token_num self.model_inputs["cu_batch_token_offset"] = cu_batch_token_offset - print(f"[MTPProposer] draft_logits: {draft_logits}") - print(f"[MTPProposer] batch_token_num: {batch_token_num}") - print(f"[MTPProposer] cu_batch_token_offset: {cu_batch_token_offset}") sampled_token_ids, sampler_output = self.sampler( logits, From aed79aec4f8d6dfeb8869e172ef947c4e401b9de Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Thu, 25 Sep 2025 21:06:10 +0800 Subject: [PATCH 03/10] fix --- custom_ops/gpu_ops/rebuild_padding.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index a80db03035..130b6d6062 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -143,7 +143,6 @@ std::vector rebuild_padding( int pack_num = elem_nums / PackSize; const int blocksize = 128; const int grid_size = (pack_num + blocksize - 1) / blocksize; - printf("elem_nums: %d\n", elem_nums); if (output_padding_offset) { RebuildAppendPaddingKernel From d5a3c5c933ad0cbbae70cff6bbf9c317c0681ebc Mon Sep 17 00:00:00 2001 From: sunlei1024 Date: Fri, 26 Sep 2025 13:07:48 +0800 Subject: [PATCH 04/10] feat: add draft_logprobs for Speculative Decode MTP --- fastdeploy/engine/request.py | 4 + fastdeploy/entrypoints/openai/protocol.py | 2 + fastdeploy/entrypoints/openai/serving_chat.py | 16 ++ .../entrypoints/openai/serving_completion.py | 16 ++ fastdeploy/output/token_processor.py | 112 +++++++++--- tests/output/test_process_batch_output.py | 
167 ++++++++++++++++++ 6 files changed, 293 insertions(+), 24 deletions(-) create mode 100644 tests/output/test_process_batch_output.py diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276afb..0cade69734 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,6 +287,7 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -412,6 +413,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -456,6 +458,7 @@ def __repr__(self) -> str: f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -476,6 +479,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb46..f0805d697c 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -405,6 +405,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -540,6 +541,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 125d785fe3..c1e189a366 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -295,10 +295,15 @@ async def chat_completion_stream_generator( output_top_logprobs = output["top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs: + draft_logprobs_res = self._create_chat_logprobs( + output_top_logprobs, request.logprobs, request.draft_top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -326,6 +331,7 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -461,11 +467,21 @@ async def chat_completion_full_generator( output = data["outputs"] output_top_logprobs = 
output["top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents.extend(logprobs_res.content) + + # draf_logprobs + if request.include_draft_logprobs: + draft_logprobs_res = self._create_chat_logprobs( + output_top_logprobs, request.logprobs, request.draft_top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprobs_res.extend(logprobs_res.content) + if data["finished"]: final_res = data task_is_finished = True diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9b089d073d..e0d88d5444 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -212,6 +212,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -239,11 +240,18 @@ async def completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -390,10 +398,17 @@ async def completion_stream_generator( await self._echo_back_prompt(request, res, idx) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs: + draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -406,6 +421,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e48260fc66..42a906f975 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -109,6 +109,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status 
= dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -165,7 +166,20 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + if self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + elif self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue @@ -213,7 +227,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result, mtype=3): """ single post-processing function @@ -221,7 +235,21 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.cfg.use_logprobs: + if mtype == 3: # target + self._batch_result_buffer = batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"] + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -302,9 +330,19 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = self.output_tokens[1, 0] + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape( + [batch, K + 1, MAX_DRAFT_TOKENS] + ) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -332,19 +370,24 @@ def _process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.to_be_rescheduled_request_id_set: - self.resource_manager.reschedule_preempt_task(task_id) - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + 
SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -387,6 +430,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -412,16 +456,36 @@ def _process_batch_output(self): result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: + # TODO 投机解码场景兼容支持 result.outputs.logprob = float(scores[i, 0]) # Construct top_logprobs topk_token_ids = tokens[i, :].tolist() topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -442,7 +506,7 @@ def _process_batch_output(self): if not is_prefill or self.cfg.scheduler_config.name == "splitwise": batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py new file mode 100644 index 0000000000..0d487c00f1 --- /dev/null +++ b/tests/output/test_process_batch_output.py @@ -0,0 +1,167 @@ +import time +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.output.token_processor import TokenProcessor + +paddle.set_device("cpu") + + +# Mock classes and constants needed for the test +class MockConfig: + class ParallelConfig: + local_data_parallel_id = 0 + + class SpeculativeConfig: + method = None + + class ModelConfig: + enable_logprob = False + + class SchedulerConfig: + name = "default" + + parallel_config = ParallelConfig() + speculative_config = SpeculativeConfig() + model_config = ModelConfig() + scheduler_config = SchedulerConfig() + + +class MockTask: + def __init__(self): + self.request_id = "test_request_1" + self.arrival_time = time.time() + self.inference_start_time = time.time() + self.schedule_start_time = time.time() + self.preprocess_end_time = time.time() - 0.1 + self.preprocess_start_time = time.time() - 0.2 + self.eos_token_ids = [2] + self.output_token_ids = [] + self.messages = "Test prompt" + self.num_cached_tokens = 
0 + self.disaggregate_info = None + self.prefill_chunk_info = None + self.prefill_chunk_num = 0 + + +class MockResourceManager: + def __init__(self): + self.stop_flags = [False] + self.tasks_list = [MockTask()] + self.to_be_rescheduled_request_id_set = set() + + def info(self): + return "Mock resource manager info" + + def reschedule_preempt_task(self, task_id): + pass + + +# Constants +RECOVERY_STOP_SIGNAL = -3 +MAX_BSZ = 512 +K = 20 +MAX_DRAFT_TOKENS = 6 +SPECULATE_MAX_BSZ = 256 + + +class TestTokenProcessorProcessBatchOutput(unittest.TestCase): + + def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): + """Helper method to setup TokenProcessor with different configurations""" + cfg = MockConfig() + cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.model_config.enable_logprob = use_logprobs + + processor = TokenProcessor.__new__(TokenProcessor) + processor.cfg = cfg + processor.cached_generated_tokens = [] + processor.engine_worker_queue = Mock() + processor.split_connector = Mock() + processor.resource_manager = MockResourceManager() + processor.tokens_counter = {} + processor.total_step = 0 + processor.number_of_output_tokens = 0 + processor.prefill_result_status = {} + processor.executor = Mock() + + if speculative_decoding: + if use_logprobs: + processor.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], + fill_value=2, + dtype="int64", + ) + processor.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], + fill_value=0.0, + dtype="float32", + ) + processor.output_ranks = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS], + fill_value=0, + dtype="int64", + ) + else: + processor.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) + elif use_logprobs: + processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + else: + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + return processor + + def test_speculative_decoding_use_logprobs(self): + """Test basic speculative decoding scenario""" + processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + print(f"{processor}") + + # batch_size = 1 + # max_draft_tokens = MAX_DRAFT_TOKENS + + # # Setup speculative decoding output format + # output_tokens_np = np.full( + # SPECULATE_MAX_BSZ * max_draft_tokens + SPECULATE_MAX_BSZ + 10, + # 2, + # dtype=np.int64, + # ) + # output_tokens_np[1] = batch_size # batch size + # output_tokens_np[2:2 + batch_size] = [3] # accept numbers (3 accepted tokens) + + # # Setup draft tokens + # start_idx = 2 + SPECULATE_MAX_BSZ + # for i in range(batch_size): + # draft_tokens = np.arange(100, 100 + max_draft_tokens) + # output_tokens_np[ + # start_idx + i * max_draft_tokens:start_idx + (i + 1) * max_draft_tokens + # ] = draft_tokens + + # processor.output_tokens = paddle.to_tensor(output_tokens_np) + # processor.tokens_counter = {"test_request_1": 0} + # processor.postprocess = Mock() + + # # Mock speculative decoding metrics recording + # processor._record_speculative_decoding_mertics = Mock() + # processor._compute_speculative_status = Mock() + + # with patch.object(processor.resource_manager, "stop_flags", 
[False]): + # with patch.object(processor.resource_manager.tasks_list[0], "eos_token_ids", [2]): + # processor._process_batch_output() + + # self.assertTrue(processor._record_speculative_decoding_mertics.called) + # results = processor.postprocess.call_args[0][0] + # self.assertEqual(len(results), 1) + # # Should have 3 tokens (based on accept_num) + # self.assertEqual(len(results[0].outputs.token_ids), 3) + + +if __name__ == "__main__": + unittest.main(verbosity=2, buffer=False) From fd4ed68a5597b6cc1d383b24ebce9e42f62cebd5 Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Sun, 28 Sep 2025 11:18:44 +0800 Subject: [PATCH 05/10] Revert "feat: add draft_logprobs for Speculative Decode MTP" This reverts commit d5a3c5c933ad0cbbae70cff6bbf9c317c0681ebc. --- fastdeploy/engine/request.py | 4 - fastdeploy/entrypoints/openai/protocol.py | 2 - fastdeploy/entrypoints/openai/serving_chat.py | 16 -- .../entrypoints/openai/serving_completion.py | 16 -- fastdeploy/output/token_processor.py | 112 +++--------- tests/output/test_process_batch_output.py | 167 ------------------ 6 files changed, 24 insertions(+), 293 deletions(-) delete mode 100644 tests/output/test_process_batch_output.py diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0cade69734..04a2276afb 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,7 +287,6 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None - draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -413,7 +412,6 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, - output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -458,7 +456,6 @@ def __repr__(self) -> str: f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -479,7 +476,6 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, - "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index f0805d697c..b74e0ffb46 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -405,7 +405,6 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None - include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -541,7 +540,6 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 - include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 
c1e189a366..125d785fe3 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -295,15 +295,10 @@ async def chat_completion_stream_generator( output_top_logprobs = output["top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None - draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) - if request.include_draft_logprobs: - draft_logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.draft_top_logprobs - ) delta_message = DeltaMessage( reasoning_content="", @@ -331,7 +326,6 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, - draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -467,21 +461,11 @@ async def chat_completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] if output_top_logprobs is not None: - # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents.extend(logprobs_res.content) - - # draf_logprobs - if request.include_draft_logprobs: - draft_logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.draft_top_logprobs - ) - if draft_logprobs_res and draft_logprobs_res.content is not None: - draft_logprobs_res.extend(logprobs_res.content) - if data["finished"]: final_res = data task_is_finished = True diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index e0d88d5444..9b089d073d 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -212,7 +212,6 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] - aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -240,18 +239,11 @@ async def completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] - output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) - # draft logprobs - if request.include_draft_logprobs: - aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) - aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) - aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) - aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -398,17 +390,10 @@ async def completion_stream_generator( await self._echo_back_prompt(request, res, idx) output = res["outputs"] output_top_logprobs = output["top_logprobs"] - output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None - draft_logprobs_res: 
Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) - # draft logprobs - if request.include_draft_logprobs: - draft_logprobs_res = self._create_completion_logprobs( - output_draft_top_logprobs, request.logprobs, 0 - ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -421,7 +406,6 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, - draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 42a906f975..e48260fc66 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -109,7 +109,6 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) - self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -166,20 +165,7 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - if ( - self.cfg.parallel_config.enable_expert_parallel - and self.cfg.parallel_config.data_parallel_size > 1 - ): - if self.use_logprobs: - # TODO speculate_get_output_with_topk - pass - else: - speculate_get_output(self.output_tokens, rank_id, is_blocking, True) - elif self.use_logprobs: - # TODO speculate_get_output_with_topk - pass - else: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue @@ -227,7 +213,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result, mtype=3): + def postprocess(self, batch_result): """ single post-processing function @@ -235,21 +221,7 @@ def postprocess(self, batch_result, mtype=3): batch_result (list): batch results """ try: - if self.cfg.speculative_config.method and self.cfg.use_logprobs: - if mtype == 3: # target - self._batch_result_buffer = batch_result - elif mtype == 4: # draft - target_batch_result = [] - draft_batch_result = batch_result - for target, decode in zip(self._batch_result_buffer, draft_batch_result): - target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"] - target_batch_result.append(target) - self._batch_result_buffer = None - self.cached_generated_tokens.put_results(target_batch_result) - else: - self.cached_generated_tokens.put_results(batch_result) - else: - self.cached_generated_tokens.put_results(batch_result) + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -330,19 +302,9 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None - # target:3, draft:4 - mtype = 3 if self.cfg.speculative_config.method: - if self.use_logprobs: - mtype = self.output_tokens[1, 0] - batch = self.output_tokens[2, 0] - accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] - tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape( - [batch, K + 1, MAX_DRAFT_TOKENS] - ) - else: - 
batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -370,24 +332,19 @@ def _process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - if accept_num[i] == -3: - recovery_stop = True - if recovery_stop: - llm_logger.info(f"recovery stop signal found at task {task_id}") - token_ids = [RECOVERY_STOP_SIGNAL] - elif self.use_logprobs: - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] - else: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): - continue + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if len(token_ids) == 0 or token_ids[-1] <= 0: + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.to_be_rescheduled_request_id_set: + self.resource_manager.reschedule_preempt_task(task_id) + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -430,7 +387,6 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, - output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -456,36 +412,16 @@ def _process_batch_output(self): result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - # TODO 投机解码场景兼容支持 result.outputs.logprob = float(scores[i, 0]) # Construct top_logprobs topk_token_ids = tokens[i, :].tolist() topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - - if mtype == 3: # top_logprobs - if result.outputs.top_logprobs is None: - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) - elif mtype == 4: # draft_top_logprobs - if result.outputs.draft_top_logprobs is None: - result.outputs.draft_top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) - + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -506,7 +442,7 @@ def _process_batch_output(self): if not is_prefill or self.cfg.scheduler_config.name == "splitwise": batch_result.append(result) - self.postprocess(batch_result, mtype) + self.postprocess(batch_result) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py deleted file mode 100644 index 
0d487c00f1..0000000000 --- a/tests/output/test_process_batch_output.py +++ /dev/null @@ -1,167 +0,0 @@ -import time -import unittest -from unittest.mock import Mock - -import paddle - -from fastdeploy.output.token_processor import TokenProcessor - -paddle.set_device("cpu") - - -# Mock classes and constants needed for the test -class MockConfig: - class ParallelConfig: - local_data_parallel_id = 0 - - class SpeculativeConfig: - method = None - - class ModelConfig: - enable_logprob = False - - class SchedulerConfig: - name = "default" - - parallel_config = ParallelConfig() - speculative_config = SpeculativeConfig() - model_config = ModelConfig() - scheduler_config = SchedulerConfig() - - -class MockTask: - def __init__(self): - self.request_id = "test_request_1" - self.arrival_time = time.time() - self.inference_start_time = time.time() - self.schedule_start_time = time.time() - self.preprocess_end_time = time.time() - 0.1 - self.preprocess_start_time = time.time() - 0.2 - self.eos_token_ids = [2] - self.output_token_ids = [] - self.messages = "Test prompt" - self.num_cached_tokens = 0 - self.disaggregate_info = None - self.prefill_chunk_info = None - self.prefill_chunk_num = 0 - - -class MockResourceManager: - def __init__(self): - self.stop_flags = [False] - self.tasks_list = [MockTask()] - self.to_be_rescheduled_request_id_set = set() - - def info(self): - return "Mock resource manager info" - - def reschedule_preempt_task(self, task_id): - pass - - -# Constants -RECOVERY_STOP_SIGNAL = -3 -MAX_BSZ = 512 -K = 20 -MAX_DRAFT_TOKENS = 6 -SPECULATE_MAX_BSZ = 256 - - -class TestTokenProcessorProcessBatchOutput(unittest.TestCase): - - def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): - """Helper method to setup TokenProcessor with different configurations""" - cfg = MockConfig() - cfg.speculative_config.method = "mtp" if speculative_decoding else None - cfg.model_config.enable_logprob = use_logprobs - - processor = TokenProcessor.__new__(TokenProcessor) - processor.cfg = cfg - processor.cached_generated_tokens = [] - processor.engine_worker_queue = Mock() - processor.split_connector = Mock() - processor.resource_manager = MockResourceManager() - processor.tokens_counter = {} - processor.total_step = 0 - processor.number_of_output_tokens = 0 - processor.prefill_result_status = {} - processor.executor = Mock() - - if speculative_decoding: - if use_logprobs: - processor.output_tokens = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], - fill_value=2, - dtype="int64", - ) - processor.output_scores = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], - fill_value=0.0, - dtype="float32", - ) - processor.output_ranks = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS], - fill_value=0, - dtype="int64", - ) - else: - processor.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) - elif use_logprobs: - processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") - processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") - processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") - else: - processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") - - return processor - - def test_speculative_decoding_use_logprobs(self): - """Test basic speculative decoding scenario""" - processor = 
self.setup_token_processor(speculative_decoding=True, use_logprobs=True) - print(f"{processor}") - - # batch_size = 1 - # max_draft_tokens = MAX_DRAFT_TOKENS - - # # Setup speculative decoding output format - # output_tokens_np = np.full( - # SPECULATE_MAX_BSZ * max_draft_tokens + SPECULATE_MAX_BSZ + 10, - # 2, - # dtype=np.int64, - # ) - # output_tokens_np[1] = batch_size # batch size - # output_tokens_np[2:2 + batch_size] = [3] # accept numbers (3 accepted tokens) - - # # Setup draft tokens - # start_idx = 2 + SPECULATE_MAX_BSZ - # for i in range(batch_size): - # draft_tokens = np.arange(100, 100 + max_draft_tokens) - # output_tokens_np[ - # start_idx + i * max_draft_tokens:start_idx + (i + 1) * max_draft_tokens - # ] = draft_tokens - - # processor.output_tokens = paddle.to_tensor(output_tokens_np) - # processor.tokens_counter = {"test_request_1": 0} - # processor.postprocess = Mock() - - # # Mock speculative decoding metrics recording - # processor._record_speculative_decoding_mertics = Mock() - # processor._compute_speculative_status = Mock() - - # with patch.object(processor.resource_manager, "stop_flags", [False]): - # with patch.object(processor.resource_manager.tasks_list[0], "eos_token_ids", [2]): - # processor._process_batch_output() - - # self.assertTrue(processor._record_speculative_decoding_mertics.called) - # results = processor.postprocess.call_args[0][0] - # self.assertEqual(len(results), 1) - # # Should have 3 tokens (based on accept_num) - # self.assertEqual(len(results[0].outputs.token_ids), 3) - - -if __name__ == "__main__": - unittest.main(verbosity=2, buffer=False) From 966e6e39ed2d5cb3a9ce64fc2a98cf7a978ea53e Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Sun, 28 Sep 2025 11:27:00 +0800 Subject: [PATCH 06/10] fix --- custom_ops/gpu_ops/cpp_extensions.cc | 11 +---- .../speculate_get_output_with_topk.cc | 46 +++++++------------ .../speculate_save_output_with_topk.cc | 24 ++++------ fastdeploy/worker/output.py | 2 +- 4 files changed, 27 insertions(+), 56 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 2409b9553a..2c1c4580e3 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -923,16 +923,9 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor &sampled_token_ids, const paddle::Tensor &token_num_per_batch, const paddle::Tensor &cu_batch_token_offset, const paddle::Tensor ¬_need_stop, - int mtype, + int message_flag, int64_t rank_id); -void SpeculateGetOutMmsgTopK(const paddle::Tensor &output_tokens, - const paddle::Tensor &output_scores, - const paddle::Tensor &output_ranks, - int real_k, - int64_t rank_id, - bool wait_flag); - PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -1327,6 +1320,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); m.def("speculate_save_output_topk", &SpeculateSaveOutMmsgTopK, "speculate_save_output_topk function"); - - m.def("speculate_get_output_topk", &SpeculateGetOutMmsgTopK, "speculate_get_output_topk function"); } diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 1dc7e416cc..f7ca8733d1 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ 
-26,7 +26,6 @@ #define MAX_BSZ 512 #define K 20 #define MAX_DRAFT_TOKEN_NUM 6 -// #define SPECULATE_GET_WITH_OUTPUT_DEBUG struct batch_msgdata { int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; @@ -35,7 +34,8 @@ struct batch_msgdata { }; struct msgdata { - int meta[3 + MAX_BSZ]; // stop_flag, mtype, bsz, batch_token_nums + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums batch_msgdata mtext[MAX_BSZ]; }; @@ -45,7 +45,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, int real_k, int64_t rank_id, bool wait_flag) { - static struct msgdata msg_rcv; + struct msgdata msg_rcv; int msg_queue_id = 1; if (const char* inference_msg_queue_id_env_p = @@ -76,42 +76,27 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, int ret = -1; if (!wait_flag) { ret = msgrcv( - msgid, - &msg_rcv, - (3 + MAX_BSZ) * sizeof(int) + - MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + - (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + - MAX_DRAFT_TOKEN_NUM * sizeof(int)), - 0, - IPC_NOWAIT); + msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, IPC_NOWAIT); } else { - ret = msgrcv( - msgid, - &msg_rcv, - (3 + MAX_BSZ) * sizeof(int) + - MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + - (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + - MAX_DRAFT_TOKEN_NUM * sizeof(int)), - 0, - 0); + ret = msgrcv(msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, 0); } if (ret == -1) { // read none - output_tokens_data[0] = -2; // stop_flag - output_tokens_data[1] = msg_rcv.meta[1]; // mtype, Target: 3, Draft: 4 - output_tokens_data[2] = 0; // bsz + output_tokens_data[0] = -2; // stop_flag + output_tokens_data[1] = 0; // message_flag, Target: 3, Draft: 4 + output_tokens_data[2] = 0; // bsz return; } int bsz = msg_rcv.meta[1]; - output_tokens_data[0] = msg_rcv.meta[0]; - output_tokens_data[1] = msg_rcv.meta[1]; - output_tokens_data[2] = msg_rcv.meta[2]; + output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; + output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; int output_tokens_offset = 3 + MAX_BSZ; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_rcv.meta[3 + i]; - output_tokens_data[3 + i] = cur_token_num; // batch_token_nums + output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums auto* cur_output_token = output_tokens_data + output_tokens_offset + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); @@ -121,14 +106,15 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, for (int j = 0; j < cur_token_num; j++) { for (int k = 0; k < real_k + 1; k++) { cur_output_token[j * (K + 1) + k] = - cur_batch_msg_rcv->tokens[k]; + (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k]; cur_output_score[j * (K + 1) + k] = - cur_batch_msg_rcv->scores[k]; + cur_batch_msg_rcv->scores[j * (K + 1) + k]; } output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = - cur_batch_msg_rcv->ranks[j]; + (int64_t)cur_batch_msg_rcv->ranks[j]; } } + return; } PD_BUILD_STATIC_OP(speculate_get_output_topk) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index b345893c98..78eb6c1d48 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -26,7 +26,6 @@ #define MAX_BSZ 512 #define K 20 #define MAX_DRAFT_TOKEN_NUM 6 -// #define SPECULATE_SAVE_WITH_OUTPUT_DEBUG struct batch_msgdata { int 
tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; @@ -35,7 +34,8 @@ struct batch_msgdata { }; struct msgdata { - int meta[3 + MAX_BSZ]; // stop_flag, mtype, bsz, batch_token_nums + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums batch_msgdata mtext[MAX_BSZ]; }; @@ -46,7 +46,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, const paddle::Tensor& token_num_per_batch, const paddle::Tensor& cu_batch_token_offset, const paddle::Tensor& not_need_stop, - int mtype, // Target: 3, Draft: 4 + int message_flag, // Target: 3, Draft: 4 int64_t rank_id) { if (rank_id > 0) { return; @@ -118,11 +118,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, std::cout << "save_output_key: " << key << std::endl; std::cout << "save msgid: " << msgid << std::endl; #endif - + msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = mtype; + msg_sed.meta[1] = message_flag; int bsz = token_num_per_batch.shape()[0]; msg_sed.meta[2] = bsz; int max_num_logprobs = logprob_token_ids.shape()[1]; @@ -158,8 +158,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", mtype: " << msg_sed.meta[1] << ", bsz: " << msg_sed.meta[2] - << std::endl; + << ", message_flag: " << msg_sed.meta[1] + << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; auto* cur_batch_msg_sed = &msg_sed.mtext[i]; @@ -183,13 +183,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, } std::cout << std::endl; #endif - if ((msgsnd(msgid, - &msg_sed, - (3 + MAX_BSZ) * sizeof(int) + - MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) + - (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) + - MAX_DRAFT_TOKEN_NUM * sizeof(int)), - 0)) == -1) { + if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) { printf("full msg buffer\n"); } } @@ -204,5 +198,5 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk) "cu_batch_token_offset", "not_need_stop", }) - .Attrs({"mtype: int", "rank_id: int64_t"}) + .Attrs({"message_flag: int", "rank_id: int64_t"}) .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK)); diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 9bc35d4ee0..a4176771d0 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -106,7 +106,7 @@ class SamplerOutput: # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. 
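# The default of None lets the regular (non-speculative) Sampler keep building
# SamplerOutput without this field; only the speculative/MTP samplers fill it.
# For orientation, the per-step message exchanged by the C++ side
# (speculate_save_output_with_topk.cc / speculate_get_output_with_topk.cc) is
# laid out roughly as follows (a sketch, with field names taken from the
# structs in this patch):
#
#     meta  = [stop_flag, message_flag, bsz, batch_token_nums[0..MAX_BSZ-1]]
#             # message_flag: 3 = target model output, 4 = draft (MTP) output
#     mtext = MAX_BSZ batch_msgdata entries, each holding
#             tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)],
#             scores[MAX_DRAFT_TOKEN_NUM * (K + 1)],
#             ranks[MAX_DRAFT_TOKEN_NUM]
#
# token_num_per_batch tells the writer how many of the MAX_DRAFT_TOKEN_NUM
# token slots are actually valid for each request.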
sampled_token_ids: paddle.Tensor logprobs_tensors: Optional[LogprobsTensors] - token_num_per_batch: Optional[paddle.Tensor] + token_num_per_batch: Optional[paddle.Tensor] = None @dataclass From e6954395e18e26d69e13c2bea414d87fbc0f68a4 Mon Sep 17 00:00:00 2001 From: sunlei1024 Date: Fri, 26 Sep 2025 13:07:48 +0800 Subject: [PATCH 07/10] feat: add draft_logprobs for Speculative Decode MTP --- fastdeploy/engine/request.py | 4 + fastdeploy/entrypoints/openai/protocol.py | 2 + fastdeploy/entrypoints/openai/serving_chat.py | 16 ++ .../entrypoints/openai/serving_completion.py | 16 ++ fastdeploy/output/token_processor.py | 112 +++++++++--- tests/output/test_process_batch_output.py | 167 ++++++++++++++++++ 6 files changed, 293 insertions(+), 24 deletions(-) create mode 100644 tests/output/test_process_batch_output.py diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276afb..0cade69734 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,6 +287,7 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -412,6 +413,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -456,6 +458,7 @@ def __repr__(self) -> str: f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -476,6 +479,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb46..f0805d697c 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -405,6 +405,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -540,6 +541,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 125d785fe3..c1e189a366 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -295,10 +295,15 @@ async def chat_completion_stream_generator( output_top_logprobs = output["top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if 
request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs: + draft_logprobs_res = self._create_chat_logprobs( + output_top_logprobs, request.logprobs, request.draft_top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -326,6 +331,7 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -461,11 +467,21 @@ async def chat_completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents.extend(logprobs_res.content) + + # draf_logprobs + if request.include_draft_logprobs: + draft_logprobs_res = self._create_chat_logprobs( + output_top_logprobs, request.logprobs, request.draft_top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprobs_res.extend(logprobs_res.content) + if data["finished"]: final_res = data task_is_finished = True diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9b089d073d..e0d88d5444 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -212,6 +212,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -239,11 +240,18 @@ async def completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -390,10 +398,17 @@ async def completion_stream_generator( await self._echo_back_prompt(request, res, idx) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs: + draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = 
CompletionResponseStreamChoice( index=idx, @@ -406,6 +421,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e48260fc66..42a906f975 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -109,6 +109,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -165,7 +166,20 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + if self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + elif self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue @@ -213,7 +227,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result, mtype=3): """ single post-processing function @@ -221,7 +235,21 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.cfg.use_logprobs: + if mtype == 3: # target + self._batch_result_buffer = batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"] + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -302,9 +330,19 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = self.output_tokens[1, 0] + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape( + [batch, K + 1, MAX_DRAFT_TOKENS] + ) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -332,19 +370,24 @@ def _process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + 
SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.to_be_rescheduled_request_id_set: - self.resource_manager.reschedule_preempt_task(task_id) - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -387,6 +430,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -412,16 +456,36 @@ def _process_batch_output(self): result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: + # TODO 投机解码场景兼容支持 result.outputs.logprob = float(scores[i, 0]) # Construct top_logprobs topk_token_ids = tokens[i, :].tolist() topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -442,7 +506,7 @@ def _process_batch_output(self): if not is_prefill or self.cfg.scheduler_config.name == "splitwise": batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py new file mode 100644 index 0000000000..0d487c00f1 --- /dev/null +++ b/tests/output/test_process_batch_output.py @@ -0,0 +1,167 @@ +import time +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.output.token_processor import TokenProcessor + +paddle.set_device("cpu") + + +# Mock classes and constants needed for the test +class MockConfig: + class ParallelConfig: + local_data_parallel_id = 0 + + 
class SpeculativeConfig: + method = None + + class ModelConfig: + enable_logprob = False + + class SchedulerConfig: + name = "default" + + parallel_config = ParallelConfig() + speculative_config = SpeculativeConfig() + model_config = ModelConfig() + scheduler_config = SchedulerConfig() + + +class MockTask: + def __init__(self): + self.request_id = "test_request_1" + self.arrival_time = time.time() + self.inference_start_time = time.time() + self.schedule_start_time = time.time() + self.preprocess_end_time = time.time() - 0.1 + self.preprocess_start_time = time.time() - 0.2 + self.eos_token_ids = [2] + self.output_token_ids = [] + self.messages = "Test prompt" + self.num_cached_tokens = 0 + self.disaggregate_info = None + self.prefill_chunk_info = None + self.prefill_chunk_num = 0 + + +class MockResourceManager: + def __init__(self): + self.stop_flags = [False] + self.tasks_list = [MockTask()] + self.to_be_rescheduled_request_id_set = set() + + def info(self): + return "Mock resource manager info" + + def reschedule_preempt_task(self, task_id): + pass + + +# Constants +RECOVERY_STOP_SIGNAL = -3 +MAX_BSZ = 512 +K = 20 +MAX_DRAFT_TOKENS = 6 +SPECULATE_MAX_BSZ = 256 + + +class TestTokenProcessorProcessBatchOutput(unittest.TestCase): + + def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): + """Helper method to setup TokenProcessor with different configurations""" + cfg = MockConfig() + cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.model_config.enable_logprob = use_logprobs + + processor = TokenProcessor.__new__(TokenProcessor) + processor.cfg = cfg + processor.cached_generated_tokens = [] + processor.engine_worker_queue = Mock() + processor.split_connector = Mock() + processor.resource_manager = MockResourceManager() + processor.tokens_counter = {} + processor.total_step = 0 + processor.number_of_output_tokens = 0 + processor.prefill_result_status = {} + processor.executor = Mock() + + if speculative_decoding: + if use_logprobs: + processor.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], + fill_value=2, + dtype="int64", + ) + processor.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], + fill_value=0.0, + dtype="float32", + ) + processor.output_ranks = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS], + fill_value=0, + dtype="int64", + ) + else: + processor.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) + elif use_logprobs: + processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + else: + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + return processor + + def test_speculative_decoding_use_logprobs(self): + """Test basic speculative decoding scenario""" + processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + print(f"{processor}") + + # batch_size = 1 + # max_draft_tokens = MAX_DRAFT_TOKENS + + # # Setup speculative decoding output format + # output_tokens_np = np.full( + # SPECULATE_MAX_BSZ * max_draft_tokens + SPECULATE_MAX_BSZ + 10, + # 2, + # dtype=np.int64, + # ) + # output_tokens_np[1] = batch_size # batch size + # output_tokens_np[2:2 + batch_size] = [3] # 
accept numbers (3 accepted tokens) + + # # Setup draft tokens + # start_idx = 2 + SPECULATE_MAX_BSZ + # for i in range(batch_size): + # draft_tokens = np.arange(100, 100 + max_draft_tokens) + # output_tokens_np[ + # start_idx + i * max_draft_tokens:start_idx + (i + 1) * max_draft_tokens + # ] = draft_tokens + + # processor.output_tokens = paddle.to_tensor(output_tokens_np) + # processor.tokens_counter = {"test_request_1": 0} + # processor.postprocess = Mock() + + # # Mock speculative decoding metrics recording + # processor._record_speculative_decoding_mertics = Mock() + # processor._compute_speculative_status = Mock() + + # with patch.object(processor.resource_manager, "stop_flags", [False]): + # with patch.object(processor.resource_manager.tasks_list[0], "eos_token_ids", [2]): + # processor._process_batch_output() + + # self.assertTrue(processor._record_speculative_decoding_mertics.called) + # results = processor.postprocess.call_args[0][0] + # self.assertEqual(len(results), 1) + # # Should have 3 tokens (based on accept_num) + # self.assertEqual(len(results[0].outputs.token_ids), 3) + + +if __name__ == "__main__": + unittest.main(verbosity=2, buffer=False) From a5d2cc35f9d966365db551467ec595cf38f1cfb8 Mon Sep 17 00:00:00 2001 From: sunlei18 Date: Mon, 29 Sep 2025 01:50:44 +0800 Subject: [PATCH 08/10] feat: add draft_logprobs for Speculative Decode MTP --- fastdeploy/engine/request.py | 1 + fastdeploy/output/token_processor.py | 30 ++++-- tests/output/test_process_batch_output.py | 108 +++++++++++++--------- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0cade69734..9a4494cb2e 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -424,6 +424,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.output_type = output_type self.outputs = outputs self.finished = finished self.metrics = metrics diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 42a906f975..9481c286be 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -337,9 +337,15 @@ def _process_batch_output(self): mtype = self.output_tokens[1, 0] batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] - tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape( - [batch, K + 1, MAX_DRAFT_TOKENS] + tokens = tokens[3 + batch : 3 + batch + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + [batch, MAX_DRAFT_TOKENS, K + 1] ) + scores = ( + self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] + .numpy() + .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + ) + ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -450,18 +456,23 @@ def _process_batch_output(self): if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) - for token_id in token_ids: + for batch_token_index in range(len(token_ids)): + token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - # TODO 投机解码场景兼容支持 - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, 
:].tolist() - sampled_rank = ranks[i].item() + if self.cfg.speculative_config.method: + result.outputs.logprob = float(scores[batch_token_index, i, 0]) + topk_token_ids = tokens[batch_token_index, i, :].tolist() + topk_logprobs = scores[batch_token_index, i, :].tolist() + sampled_rank = ranks[batch_token_index, i].item() + else: + result.outputs.logprob = float(scores[i, 0]) + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() if mtype == 3: # top_logprobs if result.outputs.top_logprobs is None: @@ -485,7 +496,6 @@ def _process_batch_output(self): result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) - if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 0d487c00f1..2d31ca327d 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -45,6 +45,14 @@ def __init__(self): self.prefill_chunk_info = None self.prefill_chunk_num = 0 + def get(self, key: str, default_value=None): + if hasattr(self, key): + return getattr(self, key) + elif hasattr(self.sampling_params, key): + return getattr(self.sampling_params, key) + else: + return default_value + class MockResourceManager: def __init__(self): @@ -73,19 +81,36 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): """Helper method to setup TokenProcessor with different configurations""" cfg = MockConfig() cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.speculative_config.num_speculative_tokens = 1 cfg.model_config.enable_logprob = use_logprobs processor = TokenProcessor.__new__(TokenProcessor) processor.cfg = cfg processor.cached_generated_tokens = [] + processor.executor = Mock() processor.engine_worker_queue = Mock() processor.split_connector = Mock() processor.resource_manager = MockResourceManager() - processor.tokens_counter = {} + task = MockTask() + processor.resource_manager.tasks_list = [task] + processor.tokens_counter = {task.request_id: 0} processor.total_step = 0 processor.number_of_output_tokens = 0 processor.prefill_result_status = {} - processor.executor = Mock() + processor.use_logprobs = use_logprobs + processor.num_draft_tokens = 0 + processor.num_accepted_tokens = 0 + processor.num_emitted_tokens = 0 + processor.max_num_emitted_tokens = 0 + processor.num_rest_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.num_accept_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.speculative_stats_step = 0 + + # processor._recycle_resources = Mock() if speculative_decoding: if use_logprobs: @@ -113,7 +138,7 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): elif use_logprobs: processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") - processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") else: processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") @@ -122,45 +147,44 @@ def setup_token_processor(self, 
speculative_decoding=False, use_logprobs=False): def test_speculative_decoding_use_logprobs(self): """Test basic speculative decoding scenario""" processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) - print(f"{processor}") - - # batch_size = 1 - # max_draft_tokens = MAX_DRAFT_TOKENS - - # # Setup speculative decoding output format - # output_tokens_np = np.full( - # SPECULATE_MAX_BSZ * max_draft_tokens + SPECULATE_MAX_BSZ + 10, - # 2, - # dtype=np.int64, - # ) - # output_tokens_np[1] = batch_size # batch size - # output_tokens_np[2:2 + batch_size] = [3] # accept numbers (3 accepted tokens) - - # # Setup draft tokens - # start_idx = 2 + SPECULATE_MAX_BSZ - # for i in range(batch_size): - # draft_tokens = np.arange(100, 100 + max_draft_tokens) - # output_tokens_np[ - # start_idx + i * max_draft_tokens:start_idx + (i + 1) * max_draft_tokens - # ] = draft_tokens - - # processor.output_tokens = paddle.to_tensor(output_tokens_np) - # processor.tokens_counter = {"test_request_1": 0} - # processor.postprocess = Mock() - - # # Mock speculative decoding metrics recording - # processor._record_speculative_decoding_mertics = Mock() - # processor._compute_speculative_status = Mock() - - # with patch.object(processor.resource_manager, "stop_flags", [False]): - # with patch.object(processor.resource_manager.tasks_list[0], "eos_token_ids", [2]): - # processor._process_batch_output() - - # self.assertTrue(processor._record_speculative_decoding_mertics.called) - # results = processor.postprocess.call_args[0][0] - # self.assertEqual(len(results), 1) - # # Should have 3 tokens (based on accept_num) - # self.assertEqual(len(results[0].outputs.token_ids), 3) + + # stop_flag + processor.output_tokens[0, 0] = 2 + # mtype + processor.output_tokens[1, 0] = 3 # target = 3, decode = 4 + # batch + processor.output_tokens[2, 0] = 2 + # accept_num + processor.output_tokens[3, 0] = 3 + processor.output_tokens[4, 0] = 3 + + batch = processor.output_tokens[2, 0] + accept_num = [int(num[0]) for num in processor.output_tokens[3 : batch + 3]] + + # init + print(f"\nbatch: {batch}, accept_num: {accept_num}") + for i in range(batch): + for j in range(accept_num[i]): + for k in range(K + 1): + index = ( + 3 + + batch + + i * MAX_DRAFT_TOKENS * (K + 1) + + j * (K + 1) + + k + ) + print(f"i:{i}, j:{j} k:{k} index: {index}") + processor.output_tokens[index, 0] = 5 + i * 10 + j * 2 + k + processor.output_scores[i * MAX_DRAFT_TOKENS * (K + 1) + j * (K + 1) + k, 0] = float( + 0.1 * (5 + i * 10 + j * 2 + k) + ) + processor.output_ranks[i * MAX_DRAFT_TOKENS + j] = j + 1 + + print(f"{processor.output_tokens}") + print(f"{processor.output_scores}") + print(f"{processor.output_ranks}") + + # processor._process_batch_output() if __name__ == "__main__": From 33cb1cfe00357700e55c70f41ba08d0fa773051e Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Mon, 29 Sep 2025 18:27:56 +0800 Subject: [PATCH 09/10] fix some bugs --- fastdeploy/engine/request.py | 11 ++++ fastdeploy/entrypoints/openai/protocol.py | 4 ++ fastdeploy/entrypoints/openai/serving_chat.py | 17 ++++-- .../entrypoints/openai/serving_completion.py | 13 ++++- fastdeploy/output/token_processor.py | 56 +++++++++++++------ 5 files changed, 78 insertions(+), 23 deletions(-) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 9a4494cb2e..9af8d76e80 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -304,6 +304,7 @@ def to_dict(self): "token_ids": self.token_ids, "logprob": self.logprob, 
"top_logprobs": self.top_logprobs, + "draft_top_logprobs": self.draft_top_logprobs, "logprobs": self.logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, @@ -329,6 +330,8 @@ def __repr__(self) -> str: f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " f"logprobs={self.logprobs}, " + f"top_logprobs={self.top_logprobs}, " + f"draft_top_logprobs={self.draft_top_logprobs}, " ) @@ -453,6 +456,14 @@ def add(self, next_output: RequestOutput) -> None: self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids) self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs) self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks) + if next_output.outputs.draft_top_logprobs is not None: + self.outputs.draft_top_logprobs.logprob_token_ids.extend( + next_output.outputs.draft_top_logprobs.logprob_token_ids + ) + self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs) + self.outputs.draft_top_logprobs.sampled_token_ranks.extend( + next_output.outputs.draft_top_logprobs.sampled_token_ranks + ) def __repr__(self) -> str: return ( diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index f0805d697c..590a9a279b 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -184,6 +184,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -246,6 +247,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -278,6 +280,7 @@ class CompletionResponseChoice(BaseModel): completion_tokens: Optional[str] = None arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None reasoning_content: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -316,6 +319,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None text_after_process: Optional[str] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index c1e189a366..d261a650ba 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -293,6 +293,7 @@ async def chat_completion_stream_generator( output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None draft_logprobs_res: Optional[LogProbs] = None @@ -300,9 +301,9 @@ async def chat_completion_stream_generator( logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) - if request.include_draft_logprobs: 
+ if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.draft_top_logprobs + output_draft_top_logprobs, request.logprobs, request.top_logprobs ) delta_message = DeltaMessage( @@ -426,6 +427,7 @@ async def chat_completion_full_generator( previous_num_tokens = 0 current_waiting_time = 0 logprob_contents = [] + draft_logprob_contents = [] completion_token_ids = [] response_processor = ChatResponseProcessor( data_processor=self.engine_client.data_processor, @@ -466,6 +468,7 @@ async def chat_completion_full_generator( # The logprob for handling the response output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: # logprobs logprobs_res = self._create_chat_logprobs( @@ -475,12 +478,12 @@ async def chat_completion_full_generator( logprob_contents.extend(logprobs_res.content) # draf_logprobs - if request.include_draft_logprobs: + if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.draft_top_logprobs + output_draft_top_logprobs, request.logprobs, request.top_logprobs ) if draft_logprobs_res and draft_logprobs_res.content is not None: - draft_logprobs_res.extend(logprobs_res.content) + draft_logprob_contents.extend(draft_logprobs_res.content) if data["finished"]: final_res = data @@ -515,11 +518,15 @@ async def chat_completion_full_generator( logprobs_full_res = None if logprob_contents: logprobs_full_res = LogProbs(content=logprob_contents) + draft_logprobs_full_res = None + if draft_logprob_contents: + draft_logprobs_full_res = LogProbs(content=draft_logprob_contents) choice = ChatCompletionResponseChoice( index=0, message=message, logprobs=logprobs_full_res, + draft_logprobs=draft_logprobs_full_res, finish_reason=None, ) has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index e0d88d5444..8b50bb743f 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -247,7 +247,7 @@ async def completion_full_generator( aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) # draft logprobs - if request.include_draft_logprobs: + if request.include_draft_logprobs and output_draft_top_logprobs is not None: aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) @@ -262,6 +262,7 @@ async def completion_full_generator( if data.get("finished", False): data["output_token_ids"] = output_tokens[rid] data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] + data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] valid_results[rid] = data num_choices -= 1 @@ -405,7 +406,7 @@ async def completion_stream_generator( logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) # draft logprobs - if request.include_draft_logprobs: + if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_completion_logprobs( 
output_draft_top_logprobs, request.logprobs, 0 ) @@ -510,11 +511,18 @@ def request_output_to_completion_response( output = final_res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] aggregated_logprobs: Optional[CompletionLogprobs] = None if output_top_logprobs is not None: aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_draft_logprobs: Optional[CompletionLogprobs] = None + if output_draft_top_logprobs is not None: + aggregated_draft_logprobs = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) + if request.echo: assert prompt_text is not None token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -540,6 +548,7 @@ def request_output_to_completion_response( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, + draft_logprobs=aggregated_draft_logprobs, finish_reason=finish_reason, ) choices.append(choice_data) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 9481c286be..5c7abd3b84 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -22,6 +22,7 @@ import weakref from collections import Counter from concurrent.futures import ThreadPoolExecutor +from typing import List import numpy as np @@ -159,6 +160,7 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, + speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -171,17 +173,35 @@ def process_sampling_results(self): and self.cfg.parallel_config.data_parallel_size > 1 ): if self.use_logprobs: - # TODO speculate_get_output_with_topk - pass + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + True, + ) + if self.output_tokens[0][0] == -2: + continue else: speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + if self.output_tokens[0] == -2: + continue elif self.use_logprobs: - # TODO speculate_get_output_with_topk - pass + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + True, + ) + if self.output_tokens[0][0] == -2: + continue else: speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: @@ -227,7 +247,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result, mtype=3): + def postprocess(self, batch_result: List[RequestOutput], mtype=3): """ single post-processing function @@ -235,14 +255,18 @@ def postprocess(self, batch_result, mtype=3): batch_result (list): batch results """ try: - if self.cfg.speculative_config.method and self.cfg.use_logprobs: + if self.cfg.speculative_config.method and self.use_logprobs: if mtype == 3: # target - self._batch_result_buffer = batch_result + has_finished = any(r.finished for r in batch_result) + if has_finished: + self.cached_generated_tokens.put_results(batch_result) + else: + self._batch_result_buffer = batch_result elif mtype == 4: # draft target_batch_result = [] draft_batch_result = batch_result for target, decode in zip(self._batch_result_buffer, draft_batch_result): - target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"] + target.outputs.draft_top_logprobs = 
decode.outputs.draft_top_logprobs target_batch_result.append(target) self._batch_result_buffer = None self.cached_generated_tokens.put_results(target_batch_result) @@ -334,10 +358,10 @@ def _process_batch_output(self): mtype = 3 if self.cfg.speculative_config.method: if self.use_logprobs: - mtype = self.output_tokens[1, 0] + mtype = int(self.output_tokens[1, 0].item()) batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] - tokens = tokens[3 + batch : 3 + batch + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( [batch, MAX_DRAFT_TOKENS, K + 1] ) scores = ( @@ -464,10 +488,10 @@ def _process_batch_output(self): task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: - result.outputs.logprob = float(scores[batch_token_index, i, 0]) - topk_token_ids = tokens[batch_token_index, i, :].tolist() - topk_logprobs = scores[batch_token_index, i, :].tolist() - sampled_rank = ranks[batch_token_index, i].item() + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() else: result.outputs.logprob = float(scores[i, 0]) topk_token_ids = tokens[i, :].tolist() From 54e74547dd770c837ad8e8531d94660f139adf9f Mon Sep 17 00:00:00 2001 From: Deleter-D <867909454@qq.com> Date: Mon, 29 Sep 2025 18:42:46 +0800 Subject: [PATCH 10/10] fix codestyle --- tests/output/test_process_batch_output.py | 191 ---------------------- 1 file changed, 191 deletions(-) delete mode 100644 tests/output/test_process_batch_output.py diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py deleted file mode 100644 index 2d31ca327d..0000000000 --- a/tests/output/test_process_batch_output.py +++ /dev/null @@ -1,191 +0,0 @@ -import time -import unittest -from unittest.mock import Mock - -import paddle - -from fastdeploy.output.token_processor import TokenProcessor - -paddle.set_device("cpu") - - -# Mock classes and constants needed for the test -class MockConfig: - class ParallelConfig: - local_data_parallel_id = 0 - - class SpeculativeConfig: - method = None - - class ModelConfig: - enable_logprob = False - - class SchedulerConfig: - name = "default" - - parallel_config = ParallelConfig() - speculative_config = SpeculativeConfig() - model_config = ModelConfig() - scheduler_config = SchedulerConfig() - - -class MockTask: - def __init__(self): - self.request_id = "test_request_1" - self.arrival_time = time.time() - self.inference_start_time = time.time() - self.schedule_start_time = time.time() - self.preprocess_end_time = time.time() - 0.1 - self.preprocess_start_time = time.time() - 0.2 - self.eos_token_ids = [2] - self.output_token_ids = [] - self.messages = "Test prompt" - self.num_cached_tokens = 0 - self.disaggregate_info = None - self.prefill_chunk_info = None - self.prefill_chunk_num = 0 - - def get(self, key: str, default_value=None): - if hasattr(self, key): - return getattr(self, key) - elif hasattr(self.sampling_params, key): - return getattr(self.sampling_params, key) - else: - return default_value - - -class MockResourceManager: - def __init__(self): - self.stop_flags = [False] - self.tasks_list = [MockTask()] - self.to_be_rescheduled_request_id_set = set() - - def info(self): - return "Mock resource manager info" - - def 
reschedule_preempt_task(self, task_id): - pass - - -# Constants -RECOVERY_STOP_SIGNAL = -3 -MAX_BSZ = 512 -K = 20 -MAX_DRAFT_TOKENS = 6 -SPECULATE_MAX_BSZ = 256 - - -class TestTokenProcessorProcessBatchOutput(unittest.TestCase): - - def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): - """Helper method to setup TokenProcessor with different configurations""" - cfg = MockConfig() - cfg.speculative_config.method = "mtp" if speculative_decoding else None - cfg.speculative_config.num_speculative_tokens = 1 - cfg.model_config.enable_logprob = use_logprobs - - processor = TokenProcessor.__new__(TokenProcessor) - processor.cfg = cfg - processor.cached_generated_tokens = [] - processor.executor = Mock() - processor.engine_worker_queue = Mock() - processor.split_connector = Mock() - processor.resource_manager = MockResourceManager() - task = MockTask() - processor.resource_manager.tasks_list = [task] - processor.tokens_counter = {task.request_id: 0} - processor.total_step = 0 - processor.number_of_output_tokens = 0 - processor.prefill_result_status = {} - processor.use_logprobs = use_logprobs - processor.num_draft_tokens = 0 - processor.num_accepted_tokens = 0 - processor.num_emitted_tokens = 0 - processor.max_num_emitted_tokens = 0 - processor.num_rest_requests_per_head = [ - 0, - ] * MAX_DRAFT_TOKENS - processor.num_accept_requests_per_head = [ - 0, - ] * MAX_DRAFT_TOKENS - processor.speculative_stats_step = 0 - - # processor._recycle_resources = Mock() - - if speculative_decoding: - if use_logprobs: - processor.output_tokens = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], - fill_value=2, - dtype="int64", - ) - processor.output_scores = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], - fill_value=0.0, - dtype="float32", - ) - processor.output_ranks = paddle.full( - shape=[MAX_BSZ * MAX_DRAFT_TOKENS], - fill_value=0, - dtype="int64", - ) - else: - processor.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) - elif use_logprobs: - processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") - processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") - processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") - else: - processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") - - return processor - - def test_speculative_decoding_use_logprobs(self): - """Test basic speculative decoding scenario""" - processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) - - # stop_flag - processor.output_tokens[0, 0] = 2 - # mtype - processor.output_tokens[1, 0] = 3 # target = 3, decode = 4 - # batch - processor.output_tokens[2, 0] = 2 - # accept_num - processor.output_tokens[3, 0] = 3 - processor.output_tokens[4, 0] = 3 - - batch = processor.output_tokens[2, 0] - accept_num = [int(num[0]) for num in processor.output_tokens[3 : batch + 3]] - - # init - print(f"\nbatch: {batch}, accept_num: {accept_num}") - for i in range(batch): - for j in range(accept_num[i]): - for k in range(K + 1): - index = ( - 3 - + batch - + i * MAX_DRAFT_TOKENS * (K + 1) - + j * (K + 1) - + k - ) - print(f"i:{i}, j:{j} k:{k} index: {index}") - processor.output_tokens[index, 0] = 5 + i * 10 + j * 2 + k - processor.output_scores[i * MAX_DRAFT_TOKENS * (K + 1) + j * (K + 1) + k, 0] = float( - 0.1 * (5 + i * 10 + j * 
2 + k) - ) - processor.output_ranks[i * MAX_DRAFT_TOKENS + j] = j + 1 - - print(f"{processor.output_tokens}") - print(f"{processor.output_scores}") - print(f"{processor.output_ranks}") - - # processor._process_batch_output() - - -if __name__ == "__main__": - unittest.main(verbosity=2, buffer=False)
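
Usage note (illustrative, not part of the patches above): clients opt into the draft-model logprobs added by this series by setting the new `include_draft_logprobs` field on the OpenAI-compatible endpoints (protocol.py), and read the resulting `draft_logprobs` next to the usual `logprobs` in each choice (serving_chat.py / serving_completion.py). The sketch below is a minimal request assuming a FastDeploy server already running on localhost:8000; the host, port, and model name are placeholders, not anything defined by these patches.

    # Hypothetical client call; URL and model name are placeholders.
    import json
    import requests

    payload = {
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,
        "top_logprobs": 5,
        "include_draft_logprobs": True,  # new request field added in protocol.py
        "stream": False,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
    choice = resp.json()["choices"][0]
    print(json.dumps(choice.get("logprobs"), indent=2))        # target-model top-k logprobs
    print(json.dumps(choice.get("draft_logprobs"), indent=2))  # draft-model (MTP) top-k logprobs

On the server side, TokenProcessor.postprocess buffers the target batch (message_flag == 3) and attaches the following draft batch's top-k (message_flag == 4) before the results are released, so a single response carries both fields.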