diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index b0e6e332b3..1ba7ba75f8 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -259,7 +259,8 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags); + const paddle::Tensor &stop_flags, + const paddle::optional &decode_states); paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &tmp_out, // [token_num, dim_embed] @@ -279,6 +280,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &step_idx, const paddle::Tensor &stop_seqs, const paddle::Tensor &stop_seqs_len, + const paddle::optional &src_batch_ids, const bool beam_search); diff --git a/custom_ops/gpu_ops/set_value_by_flags.cu b/custom_ops/gpu_ops/set_value_by_flags.cu index 38d2ea0456..36852654b6 100644 --- a/custom_ops/gpu_ops/set_value_by_flags.cu +++ b/custom_ops/gpu_ops/set_value_by_flags.cu @@ -25,11 +25,16 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags, const int *seq_lens_encoder, const int *seq_lens_decoder, const int64_t *step_idx, + const int *decode_states, int bs, int length, int length_input_ids) { int tid = threadIdx.x; if (tid < bs && !stop_flags[tid]) { + if (decode_states) { + // just deal text mode + if (decode_states[tid] != 0) return; + } int64_t *pre_ids_all_now = pre_ids_all + tid * length; const int64_t *input_ids_now = input_ids + tid * length_input_ids; const int seq_len_dec = seq_lens_decoder[tid]; @@ -51,7 +56,8 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags) { + const paddle::Tensor &stop_flags, + const paddle::optional &decode_states) { #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place())); auto cu_stream = dev_ctx->stream(); @@ -71,6 +77,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, seq_lens_encoder.data(), seq_lens_decoder.data(), step_idx.data(), + decode_states ? decode_states.get().data() : nullptr, bs, length, length_input_ids); @@ -83,7 +90,8 @@ PD_BUILD_STATIC_OP(set_value_by_flags_and_idx) "seq_lens_encoder", "seq_lens_decoder", "step_idx", - "stop_flags"}) + "stop_flags", + paddle::Optional("decode_states")}) .Outputs({"pre_ids_all_out"}) .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) .SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx)); diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu index fe82be207f..3f2c6aebdf 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu @@ -28,6 +28,7 @@ __global__ void set_value_by_flags(bool *stop_flags, int64_t *next_tokens, const int64_t *end_ids, const int *seq_lens, + const int *src_batch_ids, const int bs, const int end_length, const int64_t *pre_ids, @@ -41,37 +42,42 @@ __global__ void set_value_by_flags(bool *stop_flags, bool prefill_one_step_stop) { int tid = threadIdx.x; int bid = blockIdx.x; + int ori_bid = bid; if (tid >= stop_seqs_bs) return; if (bid < bs) { + if (src_batch_ids) { + // mapping to original batch id + ori_bid = src_batch_ids[bid]; + } if(tid == 0){ if (prefill_one_step_stop) { - stop_flags[bid] = true; - if (seq_lens[bid] == 0) { + stop_flags[ori_bid] = true; + if (seq_lens[ori_bid] == 0) { topk_ids[bid] = -1; } - next_tokens[bid] = topk_ids[bid]; + next_tokens[ori_bid] = topk_ids[bid]; } else { - if (stop_flags[bid]) { - if (seq_lens[bid] == 0) { + if (stop_flags[ori_bid]) { + if (seq_lens[ori_bid] == 0) { topk_ids[bid] = -1; } else { topk_ids[bid] = end_ids[0]; - next_tokens[bid] = end_ids[0]; + next_tokens[ori_bid] = end_ids[0]; } } else { - next_tokens[bid] = topk_ids[bid]; + next_tokens[ori_bid] = topk_ids[bid]; } } if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { - stop_flags[bid] = true; + stop_flags[ori_bid] = true; } } // dealing stop_seqs - const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; + const int stop_seq_len = (stop_seqs_len + ori_bid * stop_seqs_bs)[tid]; if (stop_seq_len <= 0) return; - const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; - const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; - const int64_t step_idx_now = step_idx[bid]; + const int64_t *stop_seq_now = stop_seqs + ori_bid * stop_seqs_bs * stop_seqs_max_len + tid * stop_seqs_max_len; + const int64_t *pre_ids_now = pre_ids + ori_bid * pre_ids_len; + const int64_t step_idx_now = step_idx[ori_bid]; bool is_end = true; int count = 1; @@ -83,8 +89,8 @@ __global__ void set_value_by_flags(bool *stop_flags, } } if (is_end) { - next_tokens[bid] = end_ids[0]; - stop_flags[bid] = true; + next_tokens[ori_bid] = end_ids[0]; + stop_flags[ori_bid] = true; topk_ids[bid] = end_ids[0]; } } @@ -99,6 +105,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &step_idx, const paddle::Tensor &stop_seqs, const paddle::Tensor &stop_seqs_len, + const paddle::optional &src_batch_ids, const bool beam_search) { PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); @@ -128,6 +135,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const_cast(next_tokens.data()), end_ids.data(), seq_lens.data(), + src_batch_ids ? src_batch_ids.get().data() : nullptr, bs_now, end_length, pre_ids.data(), @@ -142,7 +150,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, } PD_BUILD_STATIC_OP(set_stop_value_multi_ends) - .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"}) + .Inputs({ + "topk_ids", + "stop_flags", + "seq_lens", + "end_ids", + "next_tokens", + "pre_ids", + "step_idx", + "stop_seqs", + "stop_seqs_len", + paddle::Optional("src_batch_ids")}) .Attrs({"beam_search: bool"}) .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) .SetInplaceMap({{"topk_ids", "topk_ids_out"}, diff --git a/custom_ops/gpu_ops/token_penalty_multi_scores.cu b/custom_ops/gpu_ops/token_penalty_multi_scores.cu index 7db52f38af..c5eefbe6f5 100644 --- a/custom_ops/gpu_ops/token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/token_penalty_multi_scores.cu @@ -19,15 +19,19 @@ __global__ inline void min_length_logits_process(T *logits, const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id, + const int *src_batch_ids, const int64_t bs, const int64_t vocab_size, const int64_t eos_len) { int bi = threadIdx.x; if (bi >= bs) return; - if (cur_len[bi] < 0) { + // mapping to ori batch_id + int ori_bi = bi; + if (src_batch_ids) ori_bi = src_batch_ids[bi]; + if (cur_len[ori_bi] < 0) { return; } - if (cur_len[bi] < min_len[bi]) { + if (cur_len[ori_bi] < min_len[ori_bi]) { for (int i = 0; i < eos_len; i++) { logits[bi * vocab_size + eos_token_id[i]] = -1e10; } @@ -40,15 +44,19 @@ __global__ inline void min_length_logits_process( const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id, + const int *src_batch_ids, const int64_t bs, const int64_t vocab_size, const int64_t eos_len) { int bi = threadIdx.x; if (bi >= bs) return; - if (cur_len[bi] < 0) { + // mapping to ori batch_id + int ori_bi = bi; + if (src_batch_ids) ori_bi = src_batch_ids[bi]; + if (cur_len[ori_bi] < 0) { return; } - if (cur_len[bi] < min_len[bi]) { + if (cur_len[ori_bi] < min_len[ori_bi]) { for (int i = 0; i < eos_len; i++) { logits[bi * vocab_size + eos_token_id[i]] = -1e4; } @@ -59,6 +67,7 @@ __global__ void update_repeat_times(const int64_t *pre_ids, const int64_t *prompt_ids, const int64_t *prompt_len, const int64_t *cur_len, + const int *stc_batch_ids, int *repeat_times, int *is_repeated, const int64_t bs, @@ -66,13 +75,15 @@ __global__ void update_repeat_times(const int64_t *pre_ids, const int64_t max_dec_len, const int64_t max_model_len) { int64_t bi = blockIdx.x; - if (cur_len[bi] < 0) { + int ori_bi = bi; + if (stc_batch_ids) ori_bi = stc_batch_ids[bi]; + if (cur_len[ori_bi] < 0) { return; } - const int64_t prompt_len_now = prompt_len[bi]; + const int64_t prompt_len_now = prompt_len[ori_bi]; int64_t tid = threadIdx.x; - const int64_t *prompt_now = prompt_ids + bi * max_model_len; - const int64_t *pre_ids_now = pre_ids + bi * max_dec_len; + const int64_t *prompt_now = prompt_ids + ori_bi * max_model_len; + const int64_t *pre_ids_now = pre_ids + ori_bi * max_dec_len; int *repeat_times_now = repeat_times + bi * vocab_size; int *is_repeated_now = is_repeated + bi * vocab_size; const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len; @@ -100,17 +111,20 @@ __global__ void update_value_by_repeat_times(const int *repeat_times, const T *frequency_score, const T *presence_score, const float *temperatures, + const int *stc_batch_ids, T *logits, const int64_t bs, const int64_t vocab_size) { int bi = blockIdx.x; + int ori_bi = bi; + if (stc_batch_ids) ori_bi = stc_batch_ids[bi]; int tid = threadIdx.x; T *logits_now = logits + bi * vocab_size; const int *repeat_times_now = repeat_times + bi * vocab_size; const int *is_repeated_now = is_repeated + bi * vocab_size; - float alpha = static_cast(penalty_scores[bi]); - float beta = static_cast(frequency_score[bi]); - float gamma = static_cast(presence_score[bi]); + float alpha = static_cast(penalty_scores[ori_bi]); + float beta = static_cast(frequency_score[ori_bi]); + float gamma = static_cast(presence_score[ori_bi]); for (int i = tid; i < vocab_size; i += blockDim.x) { int times = repeat_times_now[i]; float logit_now = static_cast(logits_now[i]); @@ -120,7 +134,7 @@ __global__ void update_value_by_repeat_times(const int *repeat_times, if (times != 0) { logit_now = logit_now - times * beta - gamma; } - logits_now[i] = static_cast(logit_now / temperatures[bi]); + logits_now[i] = static_cast(logit_now / temperatures[ori_bi]); } } @@ -152,7 +166,8 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, const paddle::Tensor &bad_tokens, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id) { + const paddle::Tensor &eos_token_id, + const paddle::optional &src_batch_ids) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -182,6 +197,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, cur_len.data(), min_len.data(), eos_token_id.data(), + src_batch_ids ? src_batch_ids.get().data() : nullptr, bs, vocab_size, eos_len); @@ -197,6 +213,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, prompt_ids.data(), prompt_len.data(), cur_len.data(), + src_batch_ids ? src_batch_ids.get().data() : nullptr, repeat_times.data(), is_repeated.data(), bs, @@ -220,6 +237,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, reinterpret_cast( const_cast(presence_score.data())), temperatures.data(), + src_batch_ids ? src_batch_ids.get().data() : nullptr, reinterpret_cast( const_cast(logits.data())), bs, @@ -251,7 +269,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, const paddle::Tensor &bad_tokens, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id) { + const paddle::Tensor &eos_token_id, + const paddle::optional &src_batch_ids) { switch (logits.type()) { case paddle::DataType::BFLOAT16: { return token_penalty_multi_scores_kernel< @@ -266,7 +285,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, bad_tokens, cur_len, min_len, - eos_token_id); + eos_token_id, + src_batch_ids); } case paddle::DataType::FLOAT16: { return token_penalty_multi_scores_kernel< @@ -281,7 +301,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, bad_tokens, cur_len, min_len, - eos_token_id); + eos_token_id, + src_batch_ids); } case paddle::DataType::FLOAT32: { return token_penalty_multi_scores_kernel< @@ -296,7 +317,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, bad_tokens, cur_len, min_len, - eos_token_id); + eos_token_id, + src_batch_ids); } default: { PD_THROW( @@ -319,7 +341,8 @@ PD_BUILD_STATIC_OP(get_token_penalty_multi_scores) "bad_tokens", "cur_len", "min_len", - "eos_token_id"}) + "eos_token_id", + paddle::Optional("src_batch_ids")}) .Outputs({"logits_out"}) .SetInplaceMap({{"logits", "logits_out"}}) .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));