
Commit 5855111

Authored by smallv0221, LiuChiachi, and FrostML
Add general k support for topk sampling (#792)
* fix unified transformer dtype problem
* fix win dtype bug
* Fix plato-2 and plato-mini dtype bug
* Fix plato-2 tokenization
* Refine some doc
* Add general k support for topk sampling
* fix seed
* minor fix
* Fix unitransformer readme

Co-authored-by: Jiaqi Liu <[email protected]>
Co-authored-by: liu zhengxi <[email protected]>
1 parent ac3a616 commit 5855111

File tree: 2 files changed (+59, −9)


examples/dialogue/unified_transformer/README.md

Lines changed: 0 additions & 2 deletions
```diff
@@ -81,7 +81,6 @@ python -m paddle.distributed.launch --gpus '0' --log_dir ./log finetune.py \
 |---------------------------------|
 | unified_transformer-12L-cn      |
 | unified_transformer-12L-cn-luge |
-| plato-mini                      |
 
 - `save_dir` indicates the directory where the model is saved.
 - `logging_steps` indicates the interval (in steps) between log prints.
@@ -143,7 +142,6 @@ python infer.py \
 |---------------------------------|
 | unified_transformer-12L-cn      |
 | unified_transformer-12L-cn-luge |
-| plato-mini                      |
 
 - `output_path` indicates the path where prediction results are saved.
 - `logging_steps` indicates the interval (in steps) between log prints.
```

paddlenlp/ops/patches/FasterTransformer/cuda/topk_kernels.cu

Lines changed: 59 additions & 7 deletions
```diff
@@ -65,6 +65,48 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
   }
 }
 
+template <typename T, int THREADBLOCK_SIZE>
+__global__ void beam_topK_kernel_general(const T* log_probs,
+                                         T* tmp_log_probs,
+                                         int* topk_tmp_id_buf,
+                                         T* topk_tmp_val_buf,
+                                         const int k,
+                                         const int vocab_size) {
+  const bool IS_FP16 = std::is_same<T, half>::value;
+  const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
+  typedef cub::BlockReduce<TopK_2<T>, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  TopK_2<T> partial;
+
+  for (int elem_id = tid; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
+    int index = elem_id + bid * vocab_size;
+    tmp_log_probs[index] = log_probs[index];
+  }
+
+  for (int ite = 0; ite < k; ite++) {
+    partial.init();
+#pragma unroll
+    for (int elem_id = tid; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
+      int index = elem_id + bid * vocab_size;
+      partial.insert(tmp_log_probs[index], index);
+    }
+
+    TopK_2<T> total =
+        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
+
+    if (tid == 0) {
+      const int index = bid * k + ite;
+      topk_tmp_id_buf[index] = total.p;
+      topk_tmp_val_buf[index] = total.u;
+      tmp_log_probs[total.p] = -MAX_T_VAL;
+    }
+    __syncthreads();
+  }
+}
+
 #define CASE_K(K)                                                              \
   case K:                                                                      \
     beam_topK_kernel<T, K, block_size><<<batch_size, block_size, 0, stream>>>( \
```
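
The new `beam_topK_kernel_general` removes the compile-time restriction on k: instead of templating on K, each thread block copies its row of `log_probs` into `tmp_log_probs` and then runs k successive block-wide argmax reductions, masking out each selected element with `-MAX_T_VAL` before the next pass. This makes the general path roughly O(k · vocab_size) work per batch entry, while the templated `beam_topK_kernel` specializations for k ∈ {1, 2, 4} remain the single-pass fast path. The kernel leans on the existing FasterTransformer helpers `TopK_2` and `reduce_topk_op_2`; the sketch below is a simplified paraphrase of what those helpers look like, for orientation only (it is not part of this patch, and the real definitions live elsewhere in the repo):

```cpp
// Simplified paraphrase of the FasterTransformer helpers used above.
template <typename T>
struct TopK_2 {
  int p = -1;           // index of the best element seen so far
  T u = -((T)FLT_MAX);  // value of the best element seen so far

  __device__ __forceinline__ void init() {
    p = -1;
    u = -((T)FLT_MAX);
  }

  __device__ __forceinline__ void insert(T elem, int elem_id) {
    if (elem > u) {  // keep the larger value and remember its index
      u = elem;
      p = elem_id;
    }
  }
};

// Reduction operator handed to cub::BlockReduce: keeps the entry with the
// larger value, so the block-wide result is the current argmax.
template <typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a,
                                                      const TopK_2<T>& b) {
  return a.u > b.u ? a : b;
}
```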
```diff
@@ -554,7 +596,6 @@ void topK_sampling_kernel_kernelLauncher(void* workspace,
                                          cudaStream_t stream) {
   std::minstd_rand engine;
   int seed = std::random_device()();
-
   const int batch_size = args.batch_size_;
   const int vocab_size = args.vocab_size_;
   const int candidate_num = args.candidate_num_;
```
```diff
@@ -563,25 +604,36 @@
 
   int topk_tmp_ids_buf_size =
       args.batch_size_ * args.candidate_num_;  // type int
+  int temp_log_probs_buf_size =
+      args.batch_size_ * args.candidate_num_ * vocab_size;
   int topk_tmp_val_buf_size = args.batch_size_ * args.candidate_num_;  // type T
+
+  temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
   topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
   topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
 
   if (workspace == nullptr) {
-    workspace_size = sizeof(int) * topk_tmp_ids_buf_size +
-                     sizeof(int) * topk_tmp_val_buf_size;
+    workspace_size = sizeof(T) * temp_log_probs_buf_size +
+                     sizeof(int) * topk_tmp_ids_buf_size +
+                     sizeof(T) * topk_tmp_val_buf_size;
   } else {
-    int* topk_tmp_id_buf = (int*)workspace;
+    T* temp_log_probs = (T*)workspace;
+    int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size);
     T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
 
     switch (candidate_num) {
       CASE_K(1);
       CASE_K(2);
       CASE_K(4);
       default:
-        printf("[ERROR] Topk kernel does not support candidate_num = %d \n",
-               candidate_num);
-        exit(0);
+        beam_topK_kernel_general<
+            T,
+            block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
+                                                               temp_log_probs,
+                                                               topk_tmp_id_buf,
+                                                               topk_tmp_val_buf,
+                                                               candidate_num,
+                                                               vocab_size);
         break;
     }
     sampling<T><<<batch_size, candidate_num, 0, stream>>>(topk_tmp_id_buf,
```
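
The launcher keeps FasterTransformer's two-call convention: called once with `workspace == nullptr` it only reports `workspace_size`, and on the second call it carves the provided buffer into `temp_log_probs`, `topk_tmp_id_buf`, and `topk_tmp_val_buf` before dispatching. For `candidate_num` values other than 1, 2, or 4, the `default:` branch now launches the general kernel instead of printing an error and exiting. A rough sizing sketch (illustrative numbers only, not taken from the patch), assuming T = float and the buffer sizes computed above:

```cpp
// Hypothetical example: batch_size = 8, candidate_num = 10 (previously
// unsupported, now handled by beam_topK_kernel_general), vocab_size = 30000.
//
//   temp_log_probs_buf_size = 8 * 10 * 30000 = 2,400,000  (already a multiple of 4)
//   topk_tmp_ids_buf_size   = 8 * 10         = 80         (already a multiple of 4)
//   topk_tmp_val_buf_size   = 8 * 10         = 80         (already a multiple of 4)
size_t workspace_size = sizeof(float) * 2400000   // temp_log_probs
                      + sizeof(int)   * 80        // topk_tmp_id_buf
                      + sizeof(float) * 80;       // topk_tmp_val_buf
// = 9,600,640 bytes (~9.2 MiB), partitioned in that order on the second call.
```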

0 commit comments
