Commit 370adad

Add unified transformer decoding beam search and sampling (#739)
* add unified transformer decoding beam search and sampling
1 parent f753bb1 commit 370adad

11 files changed: +2721 −11 lines


paddlenlp/ops/CMakeLists.txt

Lines changed: 31 additions & 3 deletions
@@ -25,6 +25,7 @@ option(WITH_GPU "Compile with GPU/CPU, default use CPU."
 option(USE_TENSORRT "Compile with TensorRT." OFF)
 option(WITH_TRANSFORMER "Compile with Transformer" ON)
 option(WITH_GPT "Compile with GPT" OFF)
+option(WITH_UNIFIED "Compile with Unified Transformer" ON)
 
 if(NOT WITH_GPU)
   message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ")
@@ -38,6 +39,10 @@ if(WITH_GPT)
   list(APPEND decoding_op_files fusion_gpt_op.cc fusion_gpt_op.cu)
 endif()
 
+if(WITH_UNIFIED)
+  list(APPEND decoding_op_files fusion_unified_decoding_op.cc fusion_unified_decoding_op.cu)
+endif()
+
 if(NOT WITH_TRANSFORMER AND NOT WITH_GPT)
   message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON must be set to use FasterTransformer. ")
 endif()
@@ -124,15 +129,38 @@ file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/topk_kernel
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/topk_kernels.cu topk_kernels_dst)
 
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/open_decoder.cu open_decoder_cu_dst)
-file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/open_decoder.h open_decoder_header_dst)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/open_decoder.h open_decoder_h_dst)
+
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/cuda_kernels.h cuda_kernels_h_dst)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/decoding_kernels.cu decoding_kernels_cu_dst)
 
 file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/transformer_decoder.cu trans_decoder_cu_src)
-file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/transformer_decoder.h trans_decoder_header_src)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/transformer_decoder.h trans_decoder_h_src)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/transformer_cuda_kernels.h cuda_kernels_h_src)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/transformer_decoding_kernels.cu decoding_kernels_cu_src)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/transformer_beamsearch.h beamsearch_h_src)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/transformer_sampling.h sampling_h_src)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/arguments.h arguments_h_src)
 set(trans_dst ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/)
 
 # TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it
 set(MUTE_COMMAND grep -rl "printf(\"\\[WARNING\\]" ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/ | xargs -i{} sed -i "s/printf(\"\\WWARNING\\W decoding[^)]\\{1,\\})/ /" {})
-set(FT_PATCH_COMMAND cp ${allocator_src} ${allocator_dst} && cp ${common_src} ${common_dst} && cp ${cmakelists_src} ${cmakelists_dst} && cp ${topk_kernels_src} ${topk_kernels_dst} && cat ${trans_decoder_cu_src} >> ${open_decoder_cu_dst} && cat ${trans_decoder_header_src} >> ${open_decoder_header_dst} && ${MUTE_COMMAND})
+set(FT_PATCH_COMMAND
+    cp ${allocator_src} ${allocator_dst}
+    && cp ${common_src} ${common_dst}
+    && cp ${cmakelists_src} ${cmakelists_dst}
+    && cp ${topk_kernels_src} ${topk_kernels_dst}
+    && cp ${beamsearch_h_src} ${trans_dst}
+    && cp ${sampling_h_src} ${trans_dst}
+    && cp ${arguments_h_src} ${trans_dst}
+    && cat ${trans_decoder_cu_src} >> ${open_decoder_cu_dst}
+    && cat ${trans_decoder_h_src} >> ${open_decoder_h_dst}
+    && cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
+    && cat ${decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
+    && ${MUTE_COMMAND}
+    )
 
 ######################################################################################
 # A function for automatic detection of GPUs installed (if autodetection is enabled)

paddlenlp/ops/faster_transformer/src/demo/gpt.cc

Lines changed: 0 additions & 8 deletions
@@ -84,7 +84,6 @@ bool get_result_tensor(const std::unique_ptr<paddle_infer::Tensor>& seq_ids,
 
     for (int i = 0; i < tmp_result_q.length(); ++i) {
       char32_t tmp = tmp_result_q[i];
-      // std::cout << tmp << std::endl;
       if (byte_decoder.find(tmp) != byte_decoder.end()) {
         dataresultvec[bsz].result_q = dataresultvec[bsz].result_q +
                                       static_cast<wchar_t>(byte_decoder[tmp]);
@@ -126,13 +125,6 @@ std::unordered_map<char32_t, int> convert_unicode() {
     }
   }
 
-  // std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv32;
-  // for (int i=0; i<256; ++i) {
-  //   std::cout << "=====" << std::endl;
-  //   std::cout << conv32.to_bytes(cs[i]) << std::endl;
-  //   std::cout << bs[i] << std::endl;
-  // }
-
   return ret;
 }
 

fusion_unified_decoding_op.cc (new file)

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>

#include "fusion_unified_decoding_op.h"
#include "pd_traits.h"


std::vector<paddle::Tensor> UnifiedDecodingForward(
    const std::vector<paddle::Tensor>& cache_k,
    const std::vector<paddle::Tensor>& cache_v,
    const paddle::Tensor& mem_seq_len,
    const paddle::Tensor& logits_mask,
    const paddle::Tensor& word_embedding,
    const std::vector<paddle::Tensor>& self_ln_weight,
    const std::vector<paddle::Tensor>& self_ln_bias,
    const std::vector<paddle::Tensor>& self_q_weight,
    const std::vector<paddle::Tensor>& self_q_bias,
    const std::vector<paddle::Tensor>& self_k_weight,
    const std::vector<paddle::Tensor>& self_k_bias,
    const std::vector<paddle::Tensor>& self_v_weight,
    const std::vector<paddle::Tensor>& self_v_bias,
    const std::vector<paddle::Tensor>& self_out_weight,
    const std::vector<paddle::Tensor>& self_out_bias,
    const std::vector<paddle::Tensor>& ffn_ln_weight,
    const std::vector<paddle::Tensor>& ffn_ln_bias,
    const std::vector<paddle::Tensor>& ffn_inter_weight,
    const std::vector<paddle::Tensor>& ffn_inter_bias,
    const std::vector<paddle::Tensor>& ffn_out_weight,
    const std::vector<paddle::Tensor>& ffn_out_bias,
    const paddle::Tensor& decoder_ln_weight,
    const paddle::Tensor& decoder_ln_bias,
    const paddle::Tensor& trans_weight,
    const paddle::Tensor& trans_bias,
    const paddle::Tensor& lm_ln_weight,
    const paddle::Tensor& lm_ln_bias,
    const paddle::Tensor& embedding_weight,
    const paddle::Tensor& embedding_bias,
    const paddle::Tensor& positional_embedding_weight,
    const paddle::Tensor& type_embedding_weight,
    const std::string& decoding_strategy,
    const int& beam_size,
    const int& topk,
    const float& topp,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const int64_t& max_len,
    const float& beam_search_diversity_rate,
    const int& type_id,
    const int& unk_id,
    const int& mask_id,
    const float& temperature,
    const float& len_penalty) {
  int batch_size = cache_k[0].shape()[0];

  std::vector<int64_t> output_dims;
  std::vector<int64_t> parent_ids_dims;
  std::vector<int64_t> sequence_length_dims({batch_size});
  if (decoding_strategy == "beam_search") {
    if (batch_size != -1) {
      batch_size /= beam_size;
    }
    output_dims = {max_len, batch_size, beam_size};
    parent_ids_dims = output_dims;
  } else if (decoding_strategy == "topk_sampling" ||
             decoding_strategy == "topp_sampling") {
    output_dims = {max_len, batch_size};
    parent_ids_dims = {1};
  } else {
    PD_THROW("Not supported decoding strategy. ");
  }
  auto output_ids = paddle::Tensor(cache_k[0].place(), output_dims);
  auto parent_ids = paddle::Tensor(cache_k[0].place(), parent_ids_dims);
  auto sequence_length =
      paddle::Tensor(cache_k[0].place(), sequence_length_dims);

  if (cache_k[0].place() == paddle::PlaceType::kGPU) {
    auto sequence_length = paddle::Tensor(paddle::PlaceType::kGPU);

    if (mem_seq_len.place() != paddle::PlaceType::kGPU) {
      sequence_length = mem_seq_len.copy_to<int>(paddle::PlaceType::kGPU);
    } else {
      sequence_length = mem_seq_len;
    }

    return UnifiedDecodingCUDAForward(cache_k,
                                      cache_v,
                                      sequence_length,
                                      logits_mask,
                                      word_embedding,
                                      self_ln_weight,
                                      self_ln_bias,
                                      self_q_weight,
                                      self_q_bias,
                                      self_k_weight,
                                      self_k_bias,
                                      self_v_weight,
                                      self_v_bias,
                                      self_out_weight,
                                      self_out_bias,
                                      ffn_ln_weight,
                                      ffn_ln_bias,
                                      ffn_inter_weight,
                                      ffn_inter_bias,
                                      ffn_out_weight,
                                      ffn_out_bias,
                                      decoder_ln_weight,
                                      decoder_ln_bias,
                                      trans_weight,
                                      trans_bias,
                                      lm_ln_weight,
                                      lm_ln_bias,
                                      embedding_weight,
                                      embedding_bias,
                                      positional_embedding_weight,
                                      type_embedding_weight,
                                      output_ids,
                                      parent_ids,
                                      sequence_length,
                                      decoding_strategy,
                                      beam_size,
                                      topk,
                                      topp,
                                      n_head,
                                      size_per_head,
                                      num_layer,
                                      bos_id,
                                      eos_id,
                                      max_len,
                                      beam_search_diversity_rate,
                                      type_id,
                                      unk_id,
                                      mask_id,
                                      temperature,
                                      len_penalty);
  } else {
    PD_THROW("Not implemented place. Only GPU is supported. ");
  }
}

std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
    const std::vector<std::vector<int64_t>>& cache_k_shapes,
    const std::vector<std::vector<int64_t>>& cache_v_shapes,
    const std::vector<int64_t>& mem_seq_len_shape,
    const std::vector<int64_t>& logits_mask_shape,
    const std::vector<int64_t>& word_embedding_shape,
    const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_q_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_q_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_k_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_k_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_v_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_v_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_out_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_bias_shapes,
    const std::vector<int64_t>& decoder_ln_weight_shape,
    const std::vector<int64_t>& decoder_ln_bias_shape,
    const std::vector<int64_t>& trans_weight_shape,
    const std::vector<int64_t>& trans_bias_shape,
    const std::vector<int64_t>& lm_ln_weight_shape,
    const std::vector<int64_t>& lm_ln_bias_shape,
    const std::vector<int64_t>& embedding_weight_shape,
    const std::vector<int64_t>& embedding_bias_shape,
    const std::vector<int64_t>& positional_embedding_weight_shape,
    const std::vector<int64_t>& type_embedding_weight_shape,
    const std::string& decoding_strategy,
    const int& beam_size,
    const int& topk,
    const float& topp,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const int64_t& max_len,
    const float& beam_search_diversity_rate,
    const int& type_id,
    const int& unk_id,
    const int& mask_id,
    const float& temperature,
    const float& len_penalty) {
  int batch_size = cache_k_shapes[0][0];

  std::vector<int64_t> output_dims;
  std::vector<int64_t> sequence_length_dims({batch_size});
  if (decoding_strategy == "beam_search") {
    if (batch_size != -1) {
      batch_size /= beam_size;
    }
    output_dims = {max_len, batch_size, beam_size};
    return {output_dims, output_dims, sequence_length_dims};
  } else if (decoding_strategy == "topk_sampling" ||
             decoding_strategy == "topp_sampling") {
    output_dims = {max_len, batch_size};
    return {output_dims, {1}, sequence_length_dims};
  } else {
    PD_THROW("Not supported decoding strategy. ");
  }
}

std::vector<paddle::DataType> UnifiedDecodingInferDtype(
    const std::vector<paddle::DataType>& cache_k,
    const std::vector<paddle::DataType>& cache_v,
    const paddle::DataType& mem_seq_len,
    const paddle::DataType& logits_mask,
    const paddle::DataType& word_embedding,
    const std::vector<paddle::DataType>& self_ln_weight,
    const std::vector<paddle::DataType>& self_ln_bias,
    const std::vector<paddle::DataType>& self_q_weight,
    const std::vector<paddle::DataType>& self_q_bias,
    const std::vector<paddle::DataType>& self_k_weight,
    const std::vector<paddle::DataType>& self_k_bias,
    const std::vector<paddle::DataType>& self_v_weight,
    const std::vector<paddle::DataType>& self_v_bias,
    const std::vector<paddle::DataType>& self_out_weight,
    const std::vector<paddle::DataType>& self_out_bias,
    const std::vector<paddle::DataType>& ffn_ln_weight,
    const std::vector<paddle::DataType>& ffn_ln_bias,
    const std::vector<paddle::DataType>& ffn_inter_weight,
    const std::vector<paddle::DataType>& ffn_inter_bias,
    const std::vector<paddle::DataType>& ffn_out_weight,
    const std::vector<paddle::DataType>& ffn_out_bias,
    const paddle::DataType& decoder_ln_weight,
    const paddle::DataType& decoder_ln_bias,
    const paddle::DataType& trans_weight,
    const paddle::DataType& trans_bias,
    const paddle::DataType& lm_ln_weight,
    const paddle::DataType& lm_ln_bias,
    const paddle::DataType& embedding_weight,
    const paddle::DataType& embedding_bias,
    const paddle::DataType& positional_embedding_weight,
    const paddle::DataType& type_embedding_weight) {
  return {paddle::DataType::INT32,
          paddle::DataType::INT32,
          paddle::DataType::INT32};
}

PD_BUILD_OP(fusion_unified_decoding)
    .Inputs({paddle::Vec("CacheK"),
             paddle::Vec("CacheV"),
             "MemSeqLen",
             "LogitsMask",
             "WordEmbedding",
             paddle::Vec("SelfLayernormWeight"),
             paddle::Vec("SelfLayernormBias"),
             paddle::Vec("SelfQueryWeight"),
             paddle::Vec("SelfQueryBias"),
             paddle::Vec("SelfKeyWeight"),
             paddle::Vec("SelfKeyBias"),
             paddle::Vec("SelfValueWeight"),
             paddle::Vec("SelfValueBias"),
             paddle::Vec("SelfOutWeight"),
             paddle::Vec("SelfOutBias"),
             paddle::Vec("FFNLayernormWeight"),
             paddle::Vec("FFNLayernormBias"),
             paddle::Vec("FFNInterWeight"),
             paddle::Vec("FFNInterBias"),
             paddle::Vec("FFNOutWeight"),
             paddle::Vec("FFNOutBias"),
             "DecoderLayernormWeight",
             "DecoderLayernormBias",
             "TransWeight",
             "TransBias",
             "LMLayernormWeight",
             "LMLayernormBias",
             "EmbWeight",
             "EmbBias",
             "PositionEncEmb",
             "TypeEmb"})
    .Outputs({"OutputIds", "ParentIds", "SequenceLength"})
    .Attrs({"decoding_strategy: std::string",
            "beam_size: int",
            "topk: int",
            "topp: float",
            "n_head: int",
            "size_per_head: int",
            "num_layer: int",
            "bos_id: int",
            "eos_id: int",
            "max_len: int64_t",
            "beam_search_diversity_rate: float",
            "type_id: int",
            "unk_id: int",
            "mask_id: int",
            "temperature: float",
            "len_penalty: float"})
    .SetKernelFn(PD_KERNEL(UnifiedDecodingForward))
    .SetInferShapeFn(PD_INFER_SHAPE(UnifiedDecodingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(UnifiedDecodingInferDtype));
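
At the op interface level, the three decoding strategies registered above differ only in the shapes of the OutputIds and ParentIds outputs. As a reading aid (not part of the commit), here is a minimal standalone C++ sketch that mirrors the shape rules of UnifiedDecodingInferShape; the function name ExpectedOutputDims and the sample numbers below are illustrative only:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Returns {OutputIds dims, ParentIds dims, SequenceLength dims}, mirroring
// UnifiedDecodingInferShape above. batch_size is the leading dimension of
// CacheK[0]; for beam search it is expected to already be batch * beam_size.
std::vector<std::vector<int64_t>> ExpectedOutputDims(
    const std::string& decoding_strategy,
    int64_t batch_size,
    int64_t beam_size,
    int64_t max_len) {
  std::vector<int64_t> sequence_length_dims({batch_size});
  if (decoding_strategy == "beam_search") {
    // The incoming cache is tiled per beam, so the real batch is batch / beam.
    if (batch_size != -1) {
      batch_size /= beam_size;
    }
    std::vector<int64_t> output_dims({max_len, batch_size, beam_size});
    return {output_dims, output_dims, sequence_length_dims};
  } else if (decoding_strategy == "topk_sampling" ||
             decoding_strategy == "topp_sampling") {
    // Sampling keeps one hypothesis per example; ParentIds degenerates to {1}.
    return {{max_len, batch_size}, {1}, sequence_length_dims};
  }
  throw std::invalid_argument("Not supported decoding strategy.");
}

int main() {
  // 4 examples with beam_size 4 (cache batch dim 16) and max_len 64:
  // OutputIds and ParentIds come out as {64, 4, 4}; SequenceLength as {16}.
  for (const auto& dims : ExpectedOutputDims("beam_search", 16, 4, 64)) {
    for (int64_t d : dims) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}

For beam_search, ParentIds records the selected beam index at each step so finished hypotheses can be back-traced after decoding, which is why it shares the OutputIds shape; for topk_sampling and topp_sampling there is a single hypothesis per example and ParentIds is only a placeholder.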
