diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 1ced2ce6fb..2c1c4580e3 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -332,7 +332,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length); + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob); void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &stop_flags, @@ -891,6 +893,39 @@ void SaveOutMmsgStatic(const paddle::Tensor& x, int64_t rank_id, bool save_each_rank); +std::vector SpeculateGetLogits( + const paddle::Tensor &logits, + const paddle::Tensor &first_token_logits, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateInsertFirstToken(const paddle::Tensor &token_ids, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &next_tokens, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateGetTargetLogits(const paddle::Tensor &target_logits, + const paddle::Tensor &logits, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &ori_cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &accept_num); + +void SpeculateSaveOutMmsgTopK(const paddle::Tensor &sampled_token_ids, + const paddle::Tensor &logprob_token_ids, + const paddle::Tensor &logprob_scores, + const paddle::Tensor &logprob_ranks, + const paddle::Tensor &token_num_per_batch, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor ¬_need_stop, + int message_flag, + int64_t rank_id); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -1277,4 +1312,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function"); m.def("save_output", &SaveOutMmsgStatic, "save_output function"); + + m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function"); + + m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function"); + + m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); + + m.def("speculate_save_output_topk", &SpeculateSaveOutMmsgTopK, "speculate_save_output_topk function"); } diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index 93c1bb38c2..130b6d6062 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data, template __global__ void RebuildAppendPaddingKernel(T *output_data, + T *first_token_out, const T *input_data, const int *cu_seqlens_q, const int *seq_len_this_time, @@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums, - const int bsz) { + const int bsz, + const bool enable_logprob) { AlignedVector src_vec; const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; for (int64_t i = global_idx * VecSize; i < output_elem_nums; @@ -77,6 +79,15 @@ __global__ void 
RebuildAppendPaddingKernel(T *output_data, Load(&input_data[input_token_id * dim_embed + bias_idx], &src_vec); Store(src_vec, &output_data[i]); + + if (enable_logprob && seq_len_encoder[bi] > 0) { + int first_token_seq_id = seq_len_encoder[bi] - 2; + const int first_token_id = + ori_token_id - cum_offset_bi + first_token_seq_id; + Load(&input_data[first_token_id * dim_embed + bias_idx], + &src_vec); + Store(src_vec, &first_token_out[i]); + } } } @@ -89,7 +100,9 @@ std::vector rebuild_padding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -135,6 +148,10 @@ std::vector rebuild_padding( RebuildAppendPaddingKernel <<>>( reinterpret_cast(out.data()), + first_token_out.is_initialized() + ? reinterpret_cast(const_cast( + first_token_out.get_ptr()->data())) + : nullptr, reinterpret_cast(tmp_out.data()), cu_seqlens_q.data(), seq_len_this_time.data(), @@ -144,7 +161,8 @@ std::vector rebuild_padding( max_input_length, dim_embed, elem_nums, - bsz); + bsz, + enable_logprob); } else { RebuildPaddingKernel <<>>( @@ -169,7 +187,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { switch (tmp_out.type()) { case paddle::DataType::BFLOAT16: { return rebuild_padding( @@ -179,7 +199,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT16: { return rebuild_padding( @@ -189,7 +211,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT32: { return rebuild_padding( @@ -199,7 +223,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } default: { PD_THROW( @@ -217,14 +243,18 @@ std::vector RebuildPadding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { return {RebuildPaddingFunc(tmp_out, cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)}; + first_token_out, + max_input_length, + enable_logprob)}; } std::vector> RebuildPaddingInferShape( @@ -260,9 +290,10 @@ PD_BUILD_STATIC_OP(rebuild_padding) "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", - paddle::Optional("output_padding_offset")}) + paddle::Optional("output_padding_offset"), + paddle::Optional("first_token_out")}) .Outputs({"out"}) - .Attrs({"max_input_length: int"}) + .Attrs({"max_input_length: int", "enable_logprob: bool"}) .SetKernelFn(PD_KERNEL(RebuildPadding)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype)); diff --git 
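The two new optional inputs added above are consumed on the Python side through the rebuild_padding wrapper updated later in this diff (fastdeploy/model_executor/pre_and_post_process.py) and the MTP proposer in fastdeploy/spec_decode/mtp.py. A minimal sketch of that wiring, assuming the wrapper's positional parameters follow the op's input order; the argument names are illustrative rather than the literal call sites:

import paddle

from fastdeploy.model_executor.pre_and_post_process import rebuild_padding

# Buffer that receives, for every prefill request (seq_lens_encoder > 0), the hidden
# state at prompt position seq_lens_encoder - 2; decode-only requests leave it untouched.
first_token_out = paddle.empty([max_num_seqs, hidden_size], dtype=model_output.dtype)

hidden_states = rebuild_padding(
    model_output,                        # packed [token_num, hidden_size] hidden states
    share_inputs["cu_seqlens_q"],
    share_inputs["seq_lens_this_time"],
    share_inputs["seq_lens_decoder"],
    share_inputs["seq_lens_encoder"],
    output_padding_offset,               # set on the speculative-decoding path
    max_input_length,
    first_token_out=first_token_out,
    enable_logprob=True,                 # gates the extra copy in RebuildAppendPaddingKernel
)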
a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc new file mode 100644 index 0000000000..f7ca8733d1 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, + const paddle::Tensor& output_scores, + const paddle::Tensor& output_ranks, + int real_k, + int64_t rank_id, + bool wait_flag) { + struct msgdata msg_rcv; + int msg_queue_id = 1; + + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + msg_queue_id = inference_msg_queue_id_from_env; + } + static key_t key = ftok("/dev/shm", msg_queue_id); + + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "get_output_key: " << key << std::endl; + std::cout << "get_output msgid: " << msgid << std::endl; +#endif + + int64_t* output_tokens_data = + const_cast(output_tokens.data()); + float* output_scores_data = const_cast(output_scores.data()); + int64_t* output_ranks_data = + const_cast(output_ranks.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv( + msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, 0); + } + if (ret == -1) { + // read none + output_tokens_data[0] = -2; // stop_flag + output_tokens_data[1] = 0; // message_flag, Target: 3, Draft: 4 + output_tokens_data[2] = 0; // bsz + return; + } + + int bsz = msg_rcv.meta[1]; + output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; + output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + + int output_tokens_offset = 3 + MAX_BSZ; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_rcv.meta[3 + i]; + output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums + + auto* cur_output_token = output_tokens_data + output_tokens_offset + + i * (MAX_DRAFT_TOKEN_NUM * 
(K + 1)); + auto* cur_output_score = + output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; + for (int j = 0; j < cur_token_num; j++) { + for (int k = 0; k < real_k + 1; k++) { + cur_output_token[j * (K + 1) + k] = + (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k]; + cur_output_score[j * (K + 1) + k] = + cur_batch_msg_rcv->scores[j * (K + 1) + k]; + } + output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = + (int64_t)cur_batch_msg_rcv->ranks[j]; + } + } + return; +} + +PD_BUILD_STATIC_OP(speculate_get_output_topk) + .Inputs({"output_tokens", "output_scores", "output_ranks"}) + .Attrs({"real_k: int", "rank_id: int64_t", "wait_flag: bool"}) + .Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"}) + .SetInplaceMap({{"output_tokens", "output_tokens_out"}, + {"output_scores", "output_scores_out"}, + {"output_ranks", "output_ranks_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu new file mode 100644 index 0000000000..f40e099a1e --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -0,0 +1,285 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +template +__global__ void get_token_num_per_batch_kernel(int* batch_token_num, + int* total_token_num, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int real_bsz) { + int bid = threadIdx.x; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int token_num_now = 0; + if (bid < real_bsz) { + token_num_now = seq_lens_encoder[bid] > 0 ? 
2 : seq_lens_this_time[bid]; + batch_token_num[bid] = token_num_now; + } + + __syncthreads(); + int token_num_sum = BlockReduce(temp_storage).Sum(token_num_now); + if (bid == 0) { + total_token_num[0] = token_num_sum; + } +} + +template +__global__ void speculate_get_logits_kernel(float* draft_logits, + const float* logits, + const float* first_token_logits, + const int* cu_seqlens_q, + const int* cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* draft_logits_now = + draft_logits + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + cu_seqlens_q[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&first_token_logits[bid * vocab_size + i], + &src_vec); + Store(src_vec, &draft_logits_now[i]); + + Load(&logits_now[i], &src_vec); + Store(src_vec, + &draft_logits_now[vocab_size + i]); + } else { + for (int j = 0; j < seq_lens_this_time[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &draft_logits_now[j * vocab_size + i]); + } + } + } + } +} + +std::vector SpeculateGetLogits( + const paddle::Tensor& logits, + const paddle::Tensor& first_token_logits, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + auto total_token_num = paddle::full( + {1}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + auto batch_token_num = paddle::full( + {real_bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + + constexpr int THREADBLOCK_SIZE = 512; + get_token_num_per_batch_kernel + <<<1, THREADBLOCK_SIZE, 0, cu_stream>>>(batch_token_num.data(), + total_token_num.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + real_bsz); + + auto total_token_num_cpu = + total_token_num.copy_to(paddle::CPUPlace(), true); + + auto draft_logits = + paddle::empty({total_token_num_cpu.data()[0], vocab_size}, + paddle::DataType::FLOAT32, + seq_lens_this_time.place()); + auto cu_batch_token_offset = paddle::full( + {real_bsz + 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + + void* temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(temp_storage, + temp_storage_bytes, + batch_token_num.data(), + &cu_batch_token_offset.data()[1], + real_bsz, + cu_stream); + cudaMalloc(&temp_storage, temp_storage_bytes); + cub::DeviceScan::InclusiveSum(temp_storage, + temp_storage_bytes, + batch_token_num.data(), + &cu_batch_token_offset.data()[1], + real_bsz, + cu_stream); + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_logits_kernel + <<>>( + const_cast(draft_logits.data()), + logits.data(), + first_token_logits.data(), + cu_seqlens_q.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + vocab_size, + real_bsz); + + return {draft_logits, batch_token_num, cu_batch_token_offset}; +} + +__global__ void speculate_insert_first_token_kernel( + int64_t* token_ids, + const int64_t* accept_tokens, + const int64_t* next_tokens, + const int* cu_seqlens_q, + const int* cu_batch_token_offset, + const int* 
seq_lens_this_time, + const int* seq_lens_encoder, + const int max_draft_tokens, + const int real_bsz) { + const int bid = threadIdx.x; + + auto* token_ids_now = token_ids + cu_batch_token_offset[bid]; + auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens; + auto* next_tokens_now = next_tokens + cu_seqlens_q[bid]; + if (seq_lens_encoder[bid] != 0) { + token_ids_now[0] = accept_tokens_now[0]; + token_ids_now[1] = next_tokens_now[0]; + } else { + for (int i = 0; i < seq_lens_this_time[bid]; i++) { + token_ids_now[i] = next_tokens_now[i]; + } + } +} + +void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int max_draft_tokens = accept_tokens.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + speculate_insert_first_token_kernel<<<1, real_bsz, 0, cu_stream>>>( + const_cast(token_ids.data()), + accept_tokens.data(), + next_tokens.data(), + cu_seqlens_q.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_draft_tokens, + real_bsz); +} + +template +__global__ void speculate_get_target_logits_kernel( + float* target_logtis, + const float* logits, + const int* cu_batch_token_offset, + const int* ori_cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* accept_num, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* target_logtis_now = + target_logtis + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&logits_now[i], &src_vec); + Store(src_vec, &target_logtis_now[i]); + } else { + for (int j = 0; j < accept_num[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &target_logtis_now[j * vocab_size + i]); + } + } + } + } +} + +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_target_logits_kernel + <<>>( + const_cast(target_logits.data()), + logits.data(), + cu_batch_token_offset.data(), + ori_cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + accept_num.data(), + vocab_size, + real_bsz); +} + +PD_BUILD_STATIC_OP(speculate_get_logits) + .Inputs({"logits", + "first_token_logits", + "cu_seqlens_q", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"draft_logits", "batch_token_num", "cu_batch_token_offset"}) + .SetKernelFn(PD_KERNEL(SpeculateGetLogits)); + +PD_BUILD_STATIC_OP(speculate_insert_first_token) + .Inputs({"token_ids", + "accept_tokens", + "next_tokens", + 
"cu_seqlens_q", + "cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"token_ids_out"}) + .SetInplaceMap({{"token_ids", "token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken)); + +PD_BUILD_STATIC_OP(speculate_get_target_logits) + .Inputs({"target_logits", + "logits", + "cu_batch_token_offset", + "ori_cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder", + "accept_num"}) + .Outputs({"target_logits_out"}) + .SetInplaceMap({{"target_logits", "target_logits_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc new file mode 100644 index 0000000000..78eb6c1d48 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& logprob_token_ids, + const paddle::Tensor& logprob_scores, + const paddle::Tensor& logprob_ranks, + const paddle::Tensor& token_num_per_batch, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& not_need_stop, + int message_flag, // Target: 3, Draft: 4 + int64_t rank_id) { + if (rank_id > 0) { + return; + } + auto sampled_token_ids_cpu = + sampled_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_token_ids_cpu = + logprob_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false); + auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false); + auto token_num_per_batch_cpu = + token_num_per_batch.copy_to(paddle::CPUPlace(), false); + auto cu_batch_token_offset_cpu = + cu_batch_token_offset.copy_to(paddle::CPUPlace(), false); + int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data(); + int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data(); + float* logprob_scores_data = logprob_scores_cpu.data(); + int64_t* logprob_ranks_data = logprob_ranks_cpu.data(); + int* token_num_per_batch_data = token_num_per_batch_cpu.data(); + int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data(); + + static struct msgdata msg_sed; + int msg_queue_id = 1; + if (const char* inference_msg_queue_id_env_p = + 
std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 is perserve for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout + << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." + << std::endl; +#endif + } + static key_t key = ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output_key: " << key << std::endl; + std::cout << "save msgid: " << msgid << std::endl; +#endif + msg_sed.mtype = 1; + msg_sed.meta[0] = not_need_stop.data()[0] + ? inference_msg_id_from_env + : -inference_msg_id_from_env; + msg_sed.meta[1] = message_flag; + int bsz = token_num_per_batch.shape()[0]; + msg_sed.meta[2] = bsz; + int max_num_logprobs = logprob_token_ids.shape()[1]; + for (int i = 0; i < bsz; i++) { + int cur_token_num = token_num_per_batch_data[i]; + msg_sed.meta[3 + i] = cur_token_num; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + int token_offset = cu_batch_token_offset_data[i]; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + for (int k = 0; k < K + 1; k++) { + if (k == 0) { + cur_tokens[k] = + (int)sampled_token_ids_data[token_offset + j]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else if (k < max_num_logprobs) { + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * (K + 1) + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else { + cur_tokens[k] = -1; + cur_scores[k] = 0.0; + } + } + cur_batch_msg_sed->ranks[j] = + (int)logprob_ranks_data[token_offset + j]; + } + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << msg_sed.meta[0] + << ", message_flag: " << msg_sed.meta[1] + << ", bsz: " << msg_sed.meta[2] << std::endl; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_sed.meta[3 + i]; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + std::cout << "batch " << i << " token_num: " << cur_token_num + << std::endl; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + std::cout << "tokens: "; + for (int k = 0; 
k < K + 1; k++) { + std::cout << cur_tokens[k] << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < K + 1; k++) { + std::cout << cur_scores[k] << " "; + } + std::cout << std::endl; + std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl; + } + } + std::cout << std::endl; +#endif + if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) { + printf("full msg buffer\n"); + } +} + +PD_BUILD_STATIC_OP(speculate_save_output_topk) + .Inputs({ + "sampled_token_ids", + "logprob_token_ids", + "logprob_scores", + "logprob_ranks", + "token_num_per_batch", + "cu_batch_token_offset", + "not_need_stop", + }) + .Attrs({"message_flag: int", "rank_id: int64_t"}) + .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK)); diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2f3710bfa4..e75396530b 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -403,8 +403,6 @@ def __post_init__(self): if self.dynamic_load_weight: self.enable_prefix_caching = False if self.enable_logprob: - if self.speculative_config is not None: - raise NotImplementedError("Logprob does not support speculation_config.") if not current_platform.is_cuda(): raise NotImplementedError("Only CUDA platform supports logprob.") if self.splitwise_role != "mixed": diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276afb..9af8d76e80 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,6 +287,7 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -303,6 +304,7 @@ def to_dict(self): "token_ids": self.token_ids, "logprob": self.logprob, "top_logprobs": self.top_logprobs, + "draft_top_logprobs": self.draft_top_logprobs, "logprobs": self.logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, @@ -328,6 +330,8 @@ def __repr__(self) -> str: f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " f"logprobs={self.logprobs}, " + f"top_logprobs={self.top_logprobs}, " + f"draft_top_logprobs={self.draft_top_logprobs}, " ) @@ -412,6 +416,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -422,6 +427,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.output_type = output_type self.outputs = outputs self.finished = finished self.metrics = metrics @@ -450,12 +456,21 @@ def add(self, next_output: RequestOutput) -> None: self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids) self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs) self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks) + if next_output.outputs.draft_top_logprobs is not None: + self.outputs.draft_top_logprobs.logprob_token_ids.extend( + next_output.outputs.draft_top_logprobs.logprob_token_ids + ) + self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs) + self.outputs.draft_top_logprobs.sampled_token_ranks.extend( + 
next_output.outputs.draft_top_logprobs.sampled_token_ranks + ) def __repr__(self) -> str: return ( f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -476,6 +491,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb46..590a9a279b 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -184,6 +184,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -246,6 +247,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -278,6 +280,7 @@ class CompletionResponseChoice(BaseModel): completion_tokens: Optional[str] = None arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None reasoning_content: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -316,6 +319,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None text_after_process: Optional[str] = None @@ -405,6 +409,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -540,6 +545,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 125d785fe3..d261a650ba 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -293,12 +293,18 @@ async def chat_completion_stream_generator( output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, 
request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -326,6 +332,7 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -420,6 +427,7 @@ async def chat_completion_full_generator( previous_num_tokens = 0 current_waiting_time = 0 logprob_contents = [] + draft_logprob_contents = [] completion_token_ids = [] response_processor = ChatResponseProcessor( data_processor=self.engine_client.data_processor, @@ -460,12 +468,23 @@ async def chat_completion_full_generator( # The logprob for handling the response output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents.extend(logprobs_res.content) + + # draf_logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprob_contents.extend(draft_logprobs_res.content) + if data["finished"]: final_res = data task_is_finished = True @@ -499,11 +518,15 @@ async def chat_completion_full_generator( logprobs_full_res = None if logprob_contents: logprobs_full_res = LogProbs(content=logprob_contents) + draft_logprobs_full_res = None + if draft_logprob_contents: + draft_logprobs_full_res = LogProbs(content=draft_logprob_contents) choice = ChatCompletionResponseChoice( index=0, message=message, logprobs=logprobs_full_res, + draft_logprobs=draft_logprobs_full_res, finish_reason=None, ) has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9b089d073d..8b50bb743f 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -212,6 +212,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -239,11 +240,18 @@ async def completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + 
aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -254,6 +262,7 @@ async def completion_full_generator( if data.get("finished", False): data["output_token_ids"] = output_tokens[rid] data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] + data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] valid_results[rid] = data num_choices -= 1 @@ -390,10 +399,17 @@ async def completion_stream_generator( await self._echo_back_prompt(request, res, idx) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -406,6 +422,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] @@ -494,11 +511,18 @@ def request_output_to_completion_response( output = final_res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] aggregated_logprobs: Optional[CompletionLogprobs] = None if output_top_logprobs is not None: aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_draft_logprobs: Optional[CompletionLogprobs] = None + if output_draft_top_logprobs is not None: + aggregated_draft_logprobs = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) + if request.echo: assert prompt_text is not None token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -524,6 +548,7 @@ def request_output_to_completion_response( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, + draft_logprobs=aggregated_draft_logprobs, finish_reason=finish_reason, ) choices.append(choice_data) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 5aecfa1f9e..8c5e7190d1 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -36,6 +36,10 @@ min_p_sampling, top_k_top_p_sampling, ) +from fastdeploy.model_executor.ops.gpu import ( + speculate_get_target_logits, + speculate_insert_first_token, +) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -375,6 +379,7 @@ def __init__(self, fd_config: FDConfig): self.speculative_verify_window = fd_config.speculative_config.verify_window self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = 
fd_config.speculative_config.benchmark_mode + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -389,6 +394,98 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + last_logits = logits + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + batch_token_num = share_inputs["batch_token_num"] + + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) + temp_temperature = paddle.where( + real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + last_logits = last_logits / temp_temperature + + last_logprobs = F.log_softmax(last_logits, axis=-1) + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and share_inputs is not None: + real_token_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) + ) + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(batch_token_num) + .astype("bool") + .unsqueeze(1) + ) + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. 
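+            # Column 0 of the gathered tensors is the accepted/sampled token itself;
+            # columns 1..num_logprobs are the top-k alternatives, so a single
+            # [num_tokens, num_logprobs + 1] tensor carries both for LogprobsTensors.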
+ topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -453,7 +550,53 @@ def forward_cuda( self.speculative_benchmark_mode, ) - return None + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + batch_token_num = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz] != 0, + paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), + share_inputs["accept_num"][:real_bsz].unsqueeze(1), + ).squeeze(1) + share_inputs["batch_token_num"] = batch_token_num + ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( + "int32" + ) + cu_batch_token_offset = paddle.concat( + [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])] + ).astype("int32") + share_inputs["cu_batch_token_offset"] = cu_batch_token_offset + target_logtis = paddle.empty([share_inputs["accept_num"].sum(), logits.shape[1]], dtype=logits.dtype) + speculate_get_target_logits( + target_logtis, + logits, + cu_batch_token_offset, + ori_cu_batch_token_offset, + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["accept_num"], + ) + raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata) + + sampler_output = None + if num_logprobs is not None: + + token_ids = share_inputs["accept_tokens"] + token_ids = paddle.concat( + [ + share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] + for i in range(share_inputs["accept_num"].shape[0]) + ] + ) + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=batch_token_num, + ) + + return sampler_output class MTPSampler(nn.Layer): @@ -466,6 +609,7 @@ def __init__(self, fd_config: FDConfig): self.forward = self.forward_cuda else: raise NotImplementedError + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -480,6 +624,103 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + last_logits = logits + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"]) + temp_temperature = paddle.where( + real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + last_logits = last_logits / temp_temperature + + last_logprobs = 
F.log_softmax(last_logits, axis=-1) + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and share_inputs is not None: + real_token_top_p = ( + sampling_metadata.top_p[:real_bsz] + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .unsqueeze(1) + ) + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"]) + .astype("bool") + .unsqueeze(1) + ) + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. 
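+            # As in SpeculateSampler.gather_logprobs, the sampled token is prepended to
+            # the top-k results, yielding [num_tokens, num_logprobs + 1] indices/logprobs.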
+ topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -488,6 +729,10 @@ def forward_cuda( share_inputs: List[paddle.Tensor], ) -> paddle.Tensor: """ """ + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None and share_inputs["substep"] == 0: + raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"], sampling_metadata) + logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, @@ -509,4 +754,26 @@ def forward_cuda( _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list ) - return next_tokens + + sampler_output = None + if num_logprobs is not None and share_inputs["substep"] == 0: + token_ids = paddle.empty(share_inputs["batch_token_num"].sum(), dtype="int64") + speculate_insert_first_token( + token_ids, + share_inputs["accept_tokens"], + next_tokens, + share_inputs["cu_seqlens_q"], + share_inputs["cu_batch_token_offset"], + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + ) + + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=share_inputs["batch_token_num"], + ) + + return next_tokens, sampler_output diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 796476de7f..04ef4f009d 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -529,6 +529,8 @@ def rebuild_padding( seq_lens_encoder: paddle.Tensor, output_padding_offset: Optional[paddle.Tensor] = None, max_input_length: Optional[int] = None, + first_token_out: Optional[paddle.Tensor] = None, + enable_logprob: Optional[bool] = False, ): """ Args: @@ -544,7 +546,9 @@ def rebuild_padding( seq_lens_decoder, seq_lens_encoder, output_padding_offset, + first_token_out, max_input_length, + enable_logprob, ) elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import rebuild_padding diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 73585ef776..5c7abd3b84 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -22,6 +22,7 @@ import weakref from collections import Counter from concurrent.futures import ThreadPoolExecutor +from typing import List import numpy as np @@ -60,11 +61,20 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: - self.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) + if self.use_logprobs: + self.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64" + ) + self.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32" + ) + self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64") + 
else: + self.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") @@ -100,6 +110,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -149,6 +160,7 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, + speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -156,9 +168,40 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + if self.use_logprobs: + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + True, + ) + if self.output_tokens[0][0] == -2: + continue + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + if self.output_tokens[0] == -2: + continue + elif self.use_logprobs: + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + True, + ) + if self.output_tokens[0][0] == -2: + continue + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: @@ -204,7 +247,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result: List[RequestOutput], mtype=3): """ single post-processing function @@ -212,7 +255,25 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.use_logprobs: + if mtype == 3: # target + has_finished = any(r.finished for r in batch_result) + if has_finished: + self.cached_generated_tokens.put_results(batch_result) + else: + self._batch_result_buffer = batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -293,9 +354,25 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = int(self.output_tokens[1, 0].item()) + batch = 
self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + [batch, MAX_DRAFT_TOKENS, K + 1] + ) + scores = ( + self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] + .numpy() + .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + ) + ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -323,19 +400,24 @@ def _process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.to_be_rescheduled_request_id_set: - self.resource_manager.reschedule_preempt_task(task_id) - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -378,6 +460,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -397,22 +480,46 @@ def _process_batch_output(self): if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) - for token_id in token_ids: + for batch_token_index in range(len(token_ids)): + token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) + if self.cfg.speculative_config.method: + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + else: + result.outputs.logprob = float(scores[i, 0]) + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + 
result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -433,7 +540,7 @@ def _process_batch_output(self): if not is_prefill or self.cfg.scheduler_config.name == "splitwise": batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6ec6ee1906..a02fbc05d3 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -41,6 +41,8 @@ mtp_save_first_token, mtp_step_paddle, share_external_data, + speculate_get_logits, + speculate_save_output_topk, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -62,6 +64,7 @@ def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs): self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps + self.enable_logprob = self.model_config.enable_logprob # [mixed, prefill, decoder] self.role = "mixed" @@ -354,6 +357,13 @@ def _init_model_inputs(self): self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"] + self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"] + self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"] + self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"] + self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"] + max_num_seqs = self.model_inputs["seq_lens_encoder"].shape[0] + self.model_inputs["first_token_hidden_states"] = paddle.full([max_num_seqs, self.model_config.hidden_size], -1) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): @@ -657,6 +667,10 @@ def _propose(self, target_hidden_states): min_dec_lens=self.model_inputs["min_dec_len"], bad_words_token_ids=self.model_inputs["bad_tokens"], eos_token_ids=self.model_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, ) if self.num_model_steps > 1: @@ -668,6 +682,11 @@ def _propose(self, target_hidden_states): forward_meta=self.forward_meta, ) + if self.enable_logprob and substep == 0: + first_token_hidden_states = paddle.empty( + [self.max_num_seqs, self.model_config.hidden_size], dtype=model_output.dtype + ) + hidden_states = rebuild_padding( 
                model_output,
                self.model_inputs["cu_seqlens_q"],
@@ -676,18 +695,46 @@
                 self.model_inputs["seq_lens_encoder"],
                 self.model_inputs["output_padding_offset"],
                 self.parallel_config.max_model_len,
+                first_token_hidden_states if (self.enable_logprob and substep == 0) else None,
+                self.enable_logprob if substep == 0 else False,
             )
 
             # 4. Compute logits, Sample
             logits = self.model.compute_logits(hidden_states)
-
-            sampled_token_ids = self.sampler(
+            if self.enable_logprob and substep == 0:
+                first_token_logits = self.model.compute_logits(first_token_hidden_states)
+
+                draft_logits, batch_token_num, cu_batch_token_offset = speculate_get_logits(
+                    logits,
+                    first_token_logits,
+                    self.model_inputs["cu_seqlens_q"],
+                    self.model_inputs["seq_lens_this_time"],
+                    self.model_inputs["seq_lens_encoder"],
+                )
+                self.model_inputs["draft_logits"] = draft_logits
+                self.model_inputs["batch_token_num"] = batch_token_num
+                self.model_inputs["cu_batch_token_offset"] = cu_batch_token_offset
+
+            sampled_token_ids, sampler_output = self.sampler(
                 logits,
                 self.sampling_metadata,
                 self.max_model_len,
                 self.model_inputs,
             )
+            if self.enable_logprob and substep == 0:
+                speculate_save_output_topk(
+                    sampler_output.sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    batch_token_num,
+                    cu_batch_token_offset,
+                    self.model_inputs["not_need_stop"],
+                    4,  # mtype
+                    self.local_rank,
+                )
+
             if self.parallel_config.tensor_parallel_size > 1:
                 paddle.distributed.broadcast(sampled_token_ids, 0)
 
diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py
index 6d820a873a..a4176771d0 100644
--- a/fastdeploy/worker/output.py
+++ b/fastdeploy/worker/output.py
@@ -106,6 +106,7 @@ class SamplerOutput:
     # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
     sampled_token_ids: paddle.Tensor
     logprobs_tensors: Optional[LogprobsTensors]
+    token_num_per_batch: Optional[paddle.Tensor] = None
 
 
 @dataclass
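
A note on the flattened top-k buffer consumed in _process_batch_output above: slot [0] is the ready flag (-2 means nothing to read yet), slot [1] carries mtype (3 = target, 4 = draft), slot [2] the batch size, slots [3 : 3 + batch] the per-request accept counts, and the token ids start at offset 3 + MAX_BSZ. The NumPy sketch below only mirrors that slicing so the layout is easier to follow; the constant values and the shapes in the docstring are illustrative assumptions, while the real constants and allocations live in token_processor.py.

    import numpy as np

    MAX_BSZ = 256          # placeholder; the real constant is defined in token_processor.py
    MAX_DRAFT_TOKENS = 6   # placeholder
    K = 20                 # top-k depth; this PR requests 20 logprobs per token


    def decode_speculative_topk_buffer(output_tokens, output_scores, output_ranks):
        """Mirror the slicing done for speculative + logprobs batches.

        Assumed shapes (consistent with the slicing, not taken from the source):
        output_tokens: int64, [3 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1]
        output_scores: float32, [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1]
        output_ranks:  int64, [MAX_BSZ * MAX_DRAFT_TOKENS]
        """
        mtype = int(output_tokens[1, 0])                 # 3 = target, 4 = draft
        batch = int(output_tokens[2, 0])
        accept_num = [int(n[0]) for n in output_tokens[3 : batch + 3]]
        tokens = output_tokens[
            3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)
        ].reshape([batch, MAX_DRAFT_TOKENS, K + 1])
        scores = output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
            [batch, MAX_DRAFT_TOKENS, K + 1]
        )
        ranks = output_ranks[: batch * MAX_DRAFT_TOKENS].reshape([batch, MAX_DRAFT_TOKENS])
        return mtype, accept_num, tokens, scores, ranks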
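
For every request the loop then appends one top-k row per accepted token, building top_logprobs for target batches (mtype 3) and draft_top_logprobs for draft batches (mtype 4). A compact restatement of that accumulation, with TopLogprobs as a simplified stand-in for LogprobsLists:

    from dataclasses import dataclass, field
    from typing import List

    import numpy as np


    @dataclass
    class TopLogprobs:
        """Simplified stand-in for LogprobsLists: three parallel per-token lists."""
        logprob_token_ids: List[List[int]] = field(default_factory=list)
        logprobs: List[List[float]] = field(default_factory=list)
        sampled_token_ranks: List[int] = field(default_factory=list)


    def accumulate_topk(tokens: np.ndarray, scores: np.ndarray, ranks: np.ndarray,
                        i: int, accept_count: int) -> TopLogprobs:
        """Collect one (K + 1)-wide row per accepted token of request i.

        tokens: int array [batch, MAX_DRAFT_TOKENS, K + 1]
        scores: float array [batch, MAX_DRAFT_TOKENS, K + 1]
        ranks:  int array [batch, MAX_DRAFT_TOKENS]
        """
        acc = TopLogprobs()
        for j in range(accept_count):
            acc.logprob_token_ids.append(tokens[i, j, :].tolist())
            acc.logprobs.append(scores[i, j, :].tolist())
            acc.sampled_token_ranks.append(int(ranks[i, j]))
        return acc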
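
Finally, postprocess() pairs the two passes: a target batch without finished requests is parked in _batch_result_buffer until the draft batch arrives, at which point draft_top_logprobs is copied across and the merged batch is published. The helper below is a hypothetical, self-contained sketch of that buffering contract; MergingSink, Result and Outputs are illustrative stand-ins, not classes from the codebase.

    from dataclasses import dataclass, field
    from typing import Callable, List, Optional

    TARGET, DRAFT = 3, 4  # mtype values used in the diff


    @dataclass
    class Outputs:
        top_logprobs: Optional[list] = None
        draft_top_logprobs: Optional[list] = None


    @dataclass
    class Result:
        request_id: str
        finished: bool = False
        outputs: Outputs = field(default_factory=Outputs)


    class MergingSink:
        """Hold target batches until the matching draft batch supplies its logprobs."""

        def __init__(self, publish: Callable[[List[Result]], None]):
            self._publish = publish              # e.g. cached_generated_tokens.put_results
            self._buffer: Optional[List[Result]] = None

        def put(self, batch: List[Result], mtype: int) -> None:
            if mtype == TARGET:
                if any(r.finished for r in batch):
                    self._publish(batch)         # flush immediately once a request finished
                else:
                    self._buffer = batch         # wait for the draft pass
            elif mtype == DRAFT and self._buffer is not None:
                for target, draft in zip(self._buffer, batch):
                    target.outputs.draft_top_logprobs = draft.outputs.draft_top_logprobs
                merged, self._buffer = self._buffer, None
                self._publish(merged)
            else:
                self._publish(batch)

Wired into the worker loop, such a sink would be fed exactly where self.postprocess(batch_result, mtype) is called today, with cached_generated_tokens.put_results as the publish callback.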