Skip to content

supports bid mapping #3281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags);
const paddle::Tensor &stop_flags,
const paddle::optional<paddle::Tensor> &decode_states);

paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
Expand All @@ -279,6 +280,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::optional<paddle::Tensor> &src_batch_ids,
const bool beam_search);


Expand Down
12 changes: 10 additions & 2 deletions custom_ops/gpu_ops/set_value_by_flags.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
const int *decode_states,
int bs,
int length,
int length_input_ids) {
int tid = threadIdx.x;
if (tid < bs && !stop_flags[tid]) {
if (decode_states) {
// only handle text mode (skip non-text decode states)
if (decode_states[tid] != 0) return;
}
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *input_ids_now = input_ids + tid * length_input_ids;
const int seq_len_dec = seq_lens_decoder[tid];
Expand All @@ -51,7 +56,8 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
const paddle::Tensor &stop_flags,
const paddle::optional<paddle::Tensor> &decode_states) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
auto cu_stream = dev_ctx->stream();
Expand All @@ -71,6 +77,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
step_idx.data<int64_t>(),
decode_states ? decode_states.get().data<int>() : nullptr,
bs,
length,
length_input_ids);
Expand All @@ -83,7 +90,8 @@ PD_BUILD_STATIC_OP(set_value_by_flags_and_idx)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"stop_flags"})
"stop_flags",
paddle::Optional("decode_states")})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx));
48 changes: 33 additions & 15 deletions custom_ops/gpu_ops/stop_generation_multi_ends.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ __global__ void set_value_by_flags(bool *stop_flags,
int64_t *next_tokens,
const int64_t *end_ids,
const int *seq_lens,
const int *src_batch_ids,
const int bs,
const int end_length,
const int64_t *pre_ids,
Expand All @@ -41,37 +42,42 @@ __global__ void set_value_by_flags(bool *stop_flags,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int ori_bid = bid;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
if (src_batch_ids) {
// mapping to original batch id
ori_bid = src_batch_ids[bid];
}
if(tid == 0){
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
stop_flags[ori_bid] = true;
if (seq_lens[ori_bid] == 0) {
topk_ids[bid] = -1;
}
next_tokens[bid] = topk_ids[bid];
next_tokens[ori_bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
if (stop_flags[ori_bid]) {
if (seq_lens[ori_bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
next_tokens[ori_bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
next_tokens[ori_bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
stop_flags[ori_bid] = true;
}
}
// handle stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
const int stop_seq_len = (stop_seqs_len + ori_bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
const int64_t *stop_seq_now = stop_seqs + ori_bid * stop_seqs_bs * stop_seqs_max_len + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + ori_bid * pre_ids_len;
const int64_t step_idx_now = step_idx[ori_bid];

bool is_end = true;
int count = 1;
Expand All @@ -83,8 +89,8 @@ __global__ void set_value_by_flags(bool *stop_flags,
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
next_tokens[ori_bid] = end_ids[0];
stop_flags[ori_bid] = true;
topk_ids[bid] = end_ids[0];
}
}
Expand All @@ -99,6 +105,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::optional<paddle::Tensor> &src_batch_ids,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
Expand Down Expand Up @@ -128,6 +135,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
src_batch_ids ? src_batch_ids.get().data<int>() : nullptr,
bs_now,
end_length,
pre_ids.data<int64_t>(),
Expand All @@ -142,7 +150,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
}

PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({
"topk_ids",
"stop_flags",
"seq_lens",
"end_ids",
"next_tokens",
"pre_ids",
"step_idx",
"stop_seqs",
"stop_seqs_len",
paddle::Optional("src_batch_ids")})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},
Expand Down
59 changes: 41 additions & 18 deletions custom_ops/gpu_ops/token_penalty_multi_scores.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ __global__ inline void min_length_logits_process(T *logits,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *src_batch_ids,
const int64_t bs,
const int64_t vocab_size,
const int64_t eos_len) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
// map to the original batch id
int ori_bi = bi;
if (src_batch_ids) ori_bi = src_batch_ids[bi];
if (cur_len[ori_bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
if (cur_len[ori_bi] < min_len[ori_bi]) {
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e10;
}
Expand All @@ -40,15 +44,19 @@ __global__ inline void min_length_logits_process<half>(
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *src_batch_ids,
const int64_t bs,
const int64_t vocab_size,
const int64_t eos_len) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
// map to the original batch id
int ori_bi = bi;
if (src_batch_ids) ori_bi = src_batch_ids[bi];
if (cur_len[ori_bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
if (cur_len[ori_bi] < min_len[ori_bi]) {
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e4;
}
Expand All @@ -59,20 +67,23 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
const int64_t *prompt_ids,
const int64_t *prompt_len,
const int64_t *cur_len,
const int *stc_batch_ids,
int *repeat_times,
int *is_repeated,
const int64_t bs,
const int64_t vocab_size,
const int64_t max_dec_len,
const int64_t max_model_len) {
int64_t bi = blockIdx.x;
if (cur_len[bi] < 0) {
int ori_bi = bi;
if (stc_batch_ids) ori_bi = stc_batch_ids[bi];
if (cur_len[ori_bi] < 0) {
return;
}
const int64_t prompt_len_now = prompt_len[bi];
const int64_t prompt_len_now = prompt_len[ori_bi];
int64_t tid = threadIdx.x;
const int64_t *prompt_now = prompt_ids + bi * max_model_len;
const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
const int64_t *prompt_now = prompt_ids + ori_bi * max_model_len;
const int64_t *pre_ids_now = pre_ids + ori_bi * max_dec_len;
int *repeat_times_now = repeat_times + bi * vocab_size;
int *is_repeated_now = is_repeated + bi * vocab_size;
const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
Expand Down Expand Up @@ -100,17 +111,20 @@ __global__ void update_value_by_repeat_times(const int *repeat_times,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
const int *stc_batch_ids,
T *logits,
const int64_t bs,
const int64_t vocab_size) {
int bi = blockIdx.x;
int ori_bi = bi;
if (stc_batch_ids) ori_bi = stc_batch_ids[bi];
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
const int *repeat_times_now = repeat_times + bi * vocab_size;
const int *is_repeated_now = is_repeated + bi * vocab_size;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
float alpha = static_cast<float>(penalty_scores[ori_bi]);
float beta = static_cast<float>(frequency_score[ori_bi]);
float gamma = static_cast<float>(presence_score[ori_bi]);
for (int i = tid; i < vocab_size; i += blockDim.x) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
Expand All @@ -120,7 +134,7 @@ __global__ void update_value_by_repeat_times(const int *repeat_times,
if (times != 0) {
logit_now = logit_now - times * beta - gamma;
}
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
logits_now[i] = static_cast<T>(logit_now / temperatures[ori_bi]);
}
}

Expand Down Expand Up @@ -152,7 +166,8 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
const paddle::Tensor &eos_token_id,
const paddle::optional<paddle::Tensor> &src_batch_ids) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
Expand Down Expand Up @@ -182,6 +197,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
src_batch_ids ? src_batch_ids.get().data<int>() : nullptr,
bs,
vocab_size,
eos_len);
Expand All @@ -197,6 +213,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
prompt_ids.data<int64_t>(),
prompt_len.data<int64_t>(),
cur_len.data<int64_t>(),
src_batch_ids ? src_batch_ids.get().data<int>() : nullptr,
repeat_times.data<int>(),
is_repeated.data<int>(),
bs,
Expand All @@ -220,6 +237,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(presence_score.data<data_t>())),
temperatures.data<float>(),
src_batch_ids ? src_batch_ids.get().data<int>() : nullptr,
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bs,
Expand Down Expand Up @@ -251,7 +269,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
const paddle::Tensor &eos_token_id,
const paddle::optional<paddle::Tensor> &src_batch_ids) {
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<
Expand All @@ -266,7 +285,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
bad_tokens,
cur_len,
min_len,
eos_token_id);
eos_token_id,
src_batch_ids);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<
Expand All @@ -281,7 +301,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
bad_tokens,
cur_len,
min_len,
eos_token_id);
eos_token_id,
src_batch_ids);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<
Expand All @@ -296,7 +317,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
bad_tokens,
cur_len,
min_len,
eos_token_id);
eos_token_id,
src_batch_ids);
}
default: {
PD_THROW(
Expand All @@ -319,7 +341,8 @@ PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)
"bad_tokens",
"cur_len",
"min_len",
"eos_token_id"})
"eos_token_id",
paddle::Optional("src_batch_ids")})
.Outputs({"logits_out"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
Loading