@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
428428 const paddle::Tensor& seq_len_this_time,
429429 const paddle::Tensor& seq_lens_decoder,
430430 const paddle::Tensor& seq_lens_encoder,
431- const paddle::optional<paddle::Tensor>& output_padding_offset,
431+ const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
432+ const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
432433 const paddle::optional<paddle::Tensor>& first_token_out,
433- int max_input_length,
434434 bool enable_logprob);
435435
436436void GetStopFlagsMulti (const paddle::Tensor& topk_ids,
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
747747 const paddle::Tensor& seq_lens_encoder,
748748 const paddle::Tensor& seq_lens_decoder);
749749
750- std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset (
751- const paddle::Tensor& output_cum_offsets_tmp,
752- const paddle::Tensor& out_token_num,
753- const paddle::Tensor& seq_lens_output,
750+ void SpecTokenPenaltyMultiScores (
751+ const paddle::Tensor& pre_ids,
752+ const paddle::Tensor& logits,
753+ const paddle::Tensor& penalty_scores,
754+ const paddle::Tensor& frequency_scores,
755+ const paddle::Tensor& presence_scores,
756+ const paddle::Tensor& temperatures,
757+ const paddle::Tensor& bad_tokens,
758+ const paddle::Tensor& bad_tokens_len,
759+ const paddle::Tensor& cur_len,
760+ const paddle::Tensor& min_len,
761+ const paddle::Tensor& eos_token_id,
762+ const paddle::Tensor& seq_lens_this_time,
763+ const paddle::Tensor& batch_id_per_token_output,
764+ const paddle::Tensor& cu_seqlens_q_output,
754765 const int max_seq_len);
755766
756- void SpecTokenPenaltyMultiScores (const paddle::Tensor& pre_ids,
757- const paddle::Tensor& logits,
758- const paddle::Tensor& penalty_scores,
759- const paddle::Tensor& frequency_scores,
760- const paddle::Tensor& presence_scores,
761- const paddle::Tensor& temperatures,
762- const paddle::Tensor& bad_tokens,
763- const paddle::Tensor& bad_tokens_len,
764- const paddle::Tensor& cur_len,
765- const paddle::Tensor& min_len,
766- const paddle::Tensor& eos_token_id,
767- const paddle::Tensor& seq_lens_this_time,
768- const paddle::Tensor& output_padding_offset,
769- const paddle::Tensor& output_cum_offsets,
770- const int max_seq_len);
771-
772767void SpecGetStopFlagsMultiSeqs (const paddle::Tensor& accept_tokens,
773768 const paddle::Tensor& accept_num,
774769 const paddle::Tensor& pre_ids,
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
794789 const paddle::Tensor& max_dec_len,
795790 const paddle::Tensor& end_tokens,
796791 const paddle::Tensor& is_block_step,
797- const paddle::Tensor& output_cum_offsets ,
792+ const paddle::Tensor& cu_seqlens_q_output ,
798793 const paddle::Tensor& actual_candidate_len,
799794 const paddle::Tensor& actual_draft_token_nums,
800795 const paddle::Tensor& topp,
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
922917 const paddle::Tensor& seq_lens_encoder,
923918 const paddle::Tensor& seq_lens_decoder,
924919 const paddle::Tensor& step_idx,
925- const paddle::Tensor& output_cum_offsets ,
920+ const paddle::Tensor& cu_seqlens_q_output ,
926921 const paddle::Tensor& stop_flags,
927922 const paddle::Tensor& not_need_stop,
928923 const paddle::Tensor& max_dec_len,
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
11021097
11031098std::vector<paddle::Tensor> GeluTanh (paddle::Tensor& input);
11041099
1105- void ReasoningPhaseTokenConstraint (const paddle::Tensor& logits,
1106- const paddle::Tensor& pre_ids,
1107- const paddle::Tensor& stop_flags,
1108- const paddle::Tensor& seq_lens_this_time,
1109- const paddle::Tensor& seq_lens_encoder,
1110- const paddle::Tensor& step_idx,
1111- const paddle::Tensor& allowed_tokens,
1112- const paddle::Tensor& reasoning_status,
1113- const paddle::Tensor& output_padding_offset,
1114- const paddle::Tensor& output_cum_offsets,
1115- const paddle::Tensor& enable_thinking,
1116- int64_t think_end_id,
1117- int64_t line_break_id);
1100+ void ReasoningPhaseTokenConstraint (
1101+ const paddle::Tensor& logits,
1102+ const paddle::Tensor& pre_ids,
1103+ const paddle::Tensor& stop_flags,
1104+ const paddle::Tensor& seq_lens_this_time,
1105+ const paddle::Tensor& seq_lens_encoder,
1106+ const paddle::Tensor& step_idx,
1107+ const paddle::Tensor& allowed_tokens,
1108+ const paddle::Tensor& reasoning_status,
1109+ const paddle::Tensor& batch_id_per_token_output,
1110+ const paddle::Tensor& cu_seqlens_q_output,
1111+ const paddle::Tensor& enable_thinking,
1112+ int64_t think_end_id,
1113+ int64_t line_break_id);
11181114
11191115std::vector<paddle::Tensor> get_attn_mask_q (
11201116 const paddle::Tensor& cu_seqlens_q,
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
16121608 &SpeculateGetSeqLensOutput,
16131609 " speculate_get_seq_lens_output function" );
16141610
1615- m.def (" speculate_get_output_padding_offset" ,
1616- &SpeculateGetOutputPaddingOffset,
1617- " speculate_get_output_padding_offset function" );
1618-
16191611 m.def (" speculate_get_token_penalty_multi_scores" ,
16201612 &SpecTokenPenaltyMultiScores,
16211613 " speculate_get_token_penalty_multi_scores function" );
0 commit comments