Skip to content

Commit 2b4748d

Browse files
[MTP] refactor MTP pre_process (#6358)
1 parent 18e79dd commit 2b4748d

24 files changed: +411 −533 lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
428428
const paddle::Tensor& seq_len_this_time,
429429
const paddle::Tensor& seq_lens_decoder,
430430
const paddle::Tensor& seq_lens_encoder,
431-
const paddle::optional<paddle::Tensor>& output_padding_offset,
431+
const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
432+
const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
432433
const paddle::optional<paddle::Tensor>& first_token_out,
433-
int max_input_length,
434434
bool enable_logprob);
435435

436436
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
747747
const paddle::Tensor& seq_lens_encoder,
748748
const paddle::Tensor& seq_lens_decoder);
749749

750-
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
751-
const paddle::Tensor& output_cum_offsets_tmp,
752-
const paddle::Tensor& out_token_num,
753-
const paddle::Tensor& seq_lens_output,
750+
void SpecTokenPenaltyMultiScores(
751+
const paddle::Tensor& pre_ids,
752+
const paddle::Tensor& logits,
753+
const paddle::Tensor& penalty_scores,
754+
const paddle::Tensor& frequency_scores,
755+
const paddle::Tensor& presence_scores,
756+
const paddle::Tensor& temperatures,
757+
const paddle::Tensor& bad_tokens,
758+
const paddle::Tensor& bad_tokens_len,
759+
const paddle::Tensor& cur_len,
760+
const paddle::Tensor& min_len,
761+
const paddle::Tensor& eos_token_id,
762+
const paddle::Tensor& seq_lens_this_time,
763+
const paddle::Tensor& batch_id_per_token_output,
764+
const paddle::Tensor& cu_seqlens_q_output,
754765
const int max_seq_len);
755766

756-
void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
757-
const paddle::Tensor& logits,
758-
const paddle::Tensor& penalty_scores,
759-
const paddle::Tensor& frequency_scores,
760-
const paddle::Tensor& presence_scores,
761-
const paddle::Tensor& temperatures,
762-
const paddle::Tensor& bad_tokens,
763-
const paddle::Tensor& bad_tokens_len,
764-
const paddle::Tensor& cur_len,
765-
const paddle::Tensor& min_len,
766-
const paddle::Tensor& eos_token_id,
767-
const paddle::Tensor& seq_lens_this_time,
768-
const paddle::Tensor& output_padding_offset,
769-
const paddle::Tensor& output_cum_offsets,
770-
const int max_seq_len);
771-
772767
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
773768
const paddle::Tensor& accept_num,
774769
const paddle::Tensor& pre_ids,
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
794789
const paddle::Tensor& max_dec_len,
795790
const paddle::Tensor& end_tokens,
796791
const paddle::Tensor& is_block_step,
797-
const paddle::Tensor& output_cum_offsets,
792+
const paddle::Tensor& cu_seqlens_q_output,
798793
const paddle::Tensor& actual_candidate_len,
799794
const paddle::Tensor& actual_draft_token_nums,
800795
const paddle::Tensor& topp,
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
922917
const paddle::Tensor& seq_lens_encoder,
923918
const paddle::Tensor& seq_lens_decoder,
924919
const paddle::Tensor& step_idx,
925-
const paddle::Tensor& output_cum_offsets,
920+
const paddle::Tensor& cu_seqlens_q_output,
926921
const paddle::Tensor& stop_flags,
927922
const paddle::Tensor& not_need_stop,
928923
const paddle::Tensor& max_dec_len,
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
11021097

11031098
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
11041099

1105-
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
1106-
const paddle::Tensor& pre_ids,
1107-
const paddle::Tensor& stop_flags,
1108-
const paddle::Tensor& seq_lens_this_time,
1109-
const paddle::Tensor& seq_lens_encoder,
1110-
const paddle::Tensor& step_idx,
1111-
const paddle::Tensor& allowed_tokens,
1112-
const paddle::Tensor& reasoning_status,
1113-
const paddle::Tensor& output_padding_offset,
1114-
const paddle::Tensor& output_cum_offsets,
1115-
const paddle::Tensor& enable_thinking,
1116-
int64_t think_end_id,
1117-
int64_t line_break_id);
1100+
void ReasoningPhaseTokenConstraint(
1101+
const paddle::Tensor& logits,
1102+
const paddle::Tensor& pre_ids,
1103+
const paddle::Tensor& stop_flags,
1104+
const paddle::Tensor& seq_lens_this_time,
1105+
const paddle::Tensor& seq_lens_encoder,
1106+
const paddle::Tensor& step_idx,
1107+
const paddle::Tensor& allowed_tokens,
1108+
const paddle::Tensor& reasoning_status,
1109+
const paddle::Tensor& batch_id_per_token_output,
1110+
const paddle::Tensor& cu_seqlens_q_output,
1111+
const paddle::Tensor& enable_thinking,
1112+
int64_t think_end_id,
1113+
int64_t line_break_id);
11181114

11191115
std::vector<paddle::Tensor> get_attn_mask_q(
11201116
const paddle::Tensor& cu_seqlens_q,
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
16121608
&SpeculateGetSeqLensOutput,
16131609
"speculate_get_seq_lens_output function");
16141610

1615-
m.def("speculate_get_output_padding_offset",
1616-
&SpeculateGetOutputPaddingOffset,
1617-
"speculate_get_output_padding_offset function");
1618-
16191611
m.def("speculate_get_token_penalty_multi_scores",
16201612
&SpecTokenPenaltyMultiScores,
16211613
"speculate_get_token_penalty_multi_scores function");

custom_ops/gpu_ops/reasoning_phase_token_constraint.cu

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,17 @@ __global__ void apply_token_enforce_generation_scores_kernel(
125125
T* __restrict__ logits_dst, // logits (output)
126126
const int64_t* __restrict__ allowed_tokens, // [allowed_len]
127127
const int32_t* __restrict__ reasoning_status,
128-
const int* output_padding_offset,
129-
const int* output_cum_offsets,
128+
const int* batch_id_per_token_output,
129+
const int* cu_seqlens_q_output,
130130
const int max_bsz,
131131
const int max_seq_len,
132132
const int vocab_size,
133133
const int allowed_tokens_len) {
134134
int token_idx = blockIdx.x;
135135
int tid = threadIdx.x;
136136

137-
const int bs_idx =
138-
(token_idx + output_padding_offset[token_idx]) / max_seq_len;
139-
const int query_start_token_idx =
140-
bs_idx * max_seq_len - output_cum_offsets[bs_idx];
137+
const int bs_idx = batch_id_per_token_output[token_idx];
138+
const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
141139
bool is_batch_first_token = (token_idx == query_start_token_idx);
142140

143141
if (allowed_tokens_len == 0 || !is_batch_first_token) {
@@ -177,8 +175,8 @@ void reasoning_phase_token_constraint(
177175
const paddle::Tensor& step_idx,
178176
const paddle::Tensor& allowed_tokens,
179177
const paddle::Tensor& reasoning_status,
180-
const paddle::Tensor& output_padding_offset,
181-
const paddle::Tensor& output_cum_offsets,
178+
const paddle::Tensor& batch_id_per_token_output,
179+
const paddle::Tensor& cu_seqlens_q_output,
182180
const paddle::Tensor& enable_thinking,
183181
int64_t think_end_id,
184182
int64_t line_break_id) {
@@ -233,27 +231,28 @@ void reasoning_phase_token_constraint(
233231
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
234232
allowed_tokens.data<int64_t>(),
235233
reasoning_status.data<int32_t>(),
236-
output_padding_offset.data<int32_t>(),
237-
output_cum_offsets.data<int32_t>(),
234+
batch_id_per_token_output.data<int32_t>(),
235+
cu_seqlens_q_output.data<int32_t>(),
238236
bs,
239237
max_seq_len,
240238
vocab_size,
241239
allowed_tokens_len);
242240
}
243241

244-
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
245-
const paddle::Tensor& pre_ids,
246-
const paddle::Tensor& stop_flags,
247-
const paddle::Tensor& seq_lens_this_time,
248-
const paddle::Tensor& seq_lens_encoder,
249-
const paddle::Tensor& step_idx,
250-
const paddle::Tensor& allowed_tokens,
251-
const paddle::Tensor& reasoning_status,
252-
const paddle::Tensor& output_padding_offset,
253-
const paddle::Tensor& output_cum_offsets,
254-
const paddle::Tensor& enable_thinking,
255-
int64_t think_end_id,
256-
int64_t line_break_id) {
242+
void ReasoningPhaseTokenConstraint(
243+
const paddle::Tensor& logits,
244+
const paddle::Tensor& pre_ids,
245+
const paddle::Tensor& stop_flags,
246+
const paddle::Tensor& seq_lens_this_time,
247+
const paddle::Tensor& seq_lens_encoder,
248+
const paddle::Tensor& step_idx,
249+
const paddle::Tensor& allowed_tokens,
250+
const paddle::Tensor& reasoning_status,
251+
const paddle::Tensor& batch_id_per_token_output,
252+
const paddle::Tensor& cu_seqlens_q_output,
253+
const paddle::Tensor& enable_thinking,
254+
int64_t think_end_id,
255+
int64_t line_break_id) {
257256
switch (logits.type()) {
258257
case paddle::DataType::FLOAT16:
259258
return reasoning_phase_token_constraint<paddle::DataType::FLOAT16>(
@@ -265,8 +264,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
265264
step_idx,
266265
allowed_tokens,
267266
reasoning_status,
268-
output_padding_offset,
269-
output_cum_offsets,
267+
batch_id_per_token_output,
268+
cu_seqlens_q_output,
270269
enable_thinking,
271270
think_end_id,
272271
line_break_id);
@@ -280,8 +279,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
280279
step_idx,
281280
allowed_tokens,
282281
reasoning_status,
283-
output_padding_offset,
284-
output_cum_offsets,
282+
batch_id_per_token_output,
283+
cu_seqlens_q_output,
285284
enable_thinking,
286285
think_end_id,
287286
line_break_id);
@@ -295,8 +294,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
295294
step_idx,
296295
allowed_tokens,
297296
reasoning_status,
298-
output_padding_offset,
299-
output_cum_offsets,
297+
batch_id_per_token_output,
298+
cu_seqlens_q_output,
300299
enable_thinking,
301300
think_end_id,
302301
line_break_id);
@@ -317,8 +316,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
317316
"step_idx",
318317
"allowed_tokens",
319318
"reasoning_status",
320-
"output_padding_offset",
321-
"output_cum_offsets",
319+
"batch_id_per_token_output",
320+
"cu_seqlens_q_output",
322321
"enable_thinking"})
323322
.Outputs({"logits_out", "reasoning_status_out"})
324323
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})

0 commit comments

Comments (0)