Commit 4a91065

FT support context forward acceleration (#1559)
* context acceleration support
1 parent 8801dc0, commit 4a91065

22 files changed: +1642 -969 lines
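
This commit moves the context (prefix) forward pass into the FasterTransformer decoding op: fusion_unified_decoding now consumes the raw InputIds and AttnMask instead of precomputed CacheK/CacheV tensors, gains decoder-side type/role/position inputs, a min_length attribute, and a new OutputScores output. Before the per-file diffs, here is a standalone sketch (plain C++, no Paddle headers; the helper names are illustrative and not from the source) of the size bookkeeping this change relocates from the cache tensors onto input_ids, as visible in fusion_unified_decoding_op.cc below:

#include <cstdint>
#include <vector>

// Previously the op read batch size and context length from cache_k[0]
// (only dims 0 and 2 were used); with the context forward done inside FT
// it reads them from input_ids, whose shape is [batch_size, context_len].
int BatchSize(const std::vector<int64_t>& input_ids_shape) {
  return static_cast<int>(input_ids_shape[0]);
}

int MaxOutLen(const std::vector<int64_t>& input_ids_shape,
              bool rel_len,
              int max_len) {
  // rel_len: interpret max_len as relative to the context length.
  return rel_len ? max_len + static_cast<int>(input_ids_shape[1]) : max_len;
}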

paddlenlp/ops/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -261,7 +261,6 @@ file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransforme
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/open_decoder.h open_decoder_h_dst)
 
 file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h cuda_kernels_h_src)
-#file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/transformer_cuda_kernels.h trans_cuda_kernels_h_src)
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/cuda_kernels.h cuda_kernels_h_dst)
 
 file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.cu cuda_kernels_cu_src)

paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cc

Lines changed: 86 additions & 39 deletions
@@ -19,10 +19,11 @@ limitations under the License. */
 
 
 std::vector<paddle::Tensor> UnifiedDecodingForward(
-    const std::vector<paddle::Tensor>& cache_k,
-    const std::vector<paddle::Tensor>& cache_v,
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& attn_mask,
     const paddle::Tensor& mem_seq_len,
     const paddle::Tensor& type_id,
+    const paddle::Tensor& decoder_type_id,
     const paddle::Tensor& logits_mask,
     const paddle::Tensor& word_embedding,
     const std::vector<paddle::Tensor>& self_ln_weight,
@@ -51,6 +52,11 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     const paddle::Tensor& embedding_bias,
     const paddle::Tensor& positional_embedding_weight,
     const paddle::Tensor& type_embedding_weight,
+    const paddle::Tensor& role_id,
+    const paddle::Tensor& decoder_role_id,
+    const paddle::Tensor& role_embedding_table,
+    const paddle::Tensor& position_ids,
+    const paddle::Tensor& decoder_position_ids,
     const std::string& decoding_strategy,
     const int& beam_size,
     const int& topk,
@@ -70,19 +76,22 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     const bool& pos_bias,
     const std::string& hidden_act,
     const bool& rel_len,
-    const bool& early_stopping) {
-  int batch_size = cache_k[0].shape()[0];
-  int max_out_len = rel_len ? max_len + cache_k[0].shape()[2] : max_len;
+    const bool& early_stopping,
+    const int& min_length) {
+  int batch_size = input_ids.shape()[0];
+  int max_out_len = rel_len ? max_len + input_ids.shape()[1] : max_len;
 
-  std::vector<int64_t> output_dims;
+  std::vector<int64_t> output_ids_dims;
+  std::vector<int64_t> output_scores_dims;
   std::vector<int64_t> parent_ids_dims;
   std::vector<int64_t> sequence_length_dims({batch_size});
   if (decoding_strategy == "beam_search") {
     if (batch_size != -1) {
       batch_size /= beam_size;
     }
-    output_dims = {max_out_len, batch_size, beam_size};
-    parent_ids_dims = output_dims;
+    output_ids_dims = {max_out_len, batch_size, beam_size};
+    output_scores_dims = {batch_size, beam_size};
+    parent_ids_dims = output_ids_dims;
   } else if (decoding_strategy == "beam_search_v2" ||
              decoding_strategy == "beam_search_v3") {
     // Use separated alive and finish beam queues to avoid the decrease of alive
@@ -94,22 +103,25 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     } else {
       sequence_length_dims = {batch_size};
     }
-    output_dims = {max_out_len, batch_size, beam_size * 2};
-    parent_ids_dims = output_dims;
+    output_ids_dims = {max_out_len, batch_size, beam_size * 2};
+    output_scores_dims = {batch_size, beam_size * 2};
+    parent_ids_dims = output_ids_dims;
   } else if (decoding_strategy == "topk_sampling" ||
              decoding_strategy == "topp_sampling" ||
             decoding_strategy == "sampling") {
-    output_dims = {max_out_len, batch_size};
+    output_ids_dims = {max_out_len, batch_size};
+    output_scores_dims = {batch_size};
     parent_ids_dims = {1};
   } else {
     PD_THROW("Not supported decoding strategy. ");
   }
-  auto output_ids = paddle::Tensor(cache_k[0].place(), output_dims);
-  auto parent_ids = paddle::Tensor(cache_k[0].place(), parent_ids_dims);
+  auto output_ids = paddle::Tensor(input_ids.place(), output_ids_dims);
+  auto parent_ids = paddle::Tensor(input_ids.place(), parent_ids_dims);
   auto sequence_length =
-      paddle::Tensor(cache_k[0].place(), sequence_length_dims);
+      paddle::Tensor(input_ids.place(), sequence_length_dims);
+  auto output_scores = paddle::Tensor(input_ids.place(), output_scores_dims);
 
-  if (cache_k[0].place() == paddle::PlaceType::kGPU) {
+  if (input_ids.place() == paddle::PlaceType::kGPU) {
    auto mem_seq_length = paddle::Tensor(paddle::PlaceType::kGPU);
 
    if (mem_seq_len.place() != paddle::PlaceType::kGPU) {
@@ -118,10 +130,11 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
       mem_seq_length = mem_seq_len;
     }
 
-    return UnifiedDecodingCUDAForward(cache_k,
-                                      cache_v,
+    return UnifiedDecodingCUDAForward(input_ids,
+                                      attn_mask,
                                       mem_seq_length,
                                       type_id,
+                                      decoder_type_id,
                                       logits_mask,
                                       word_embedding,
                                       self_ln_weight,
@@ -150,9 +163,15 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
                                       embedding_bias,
                                       positional_embedding_weight,
                                       type_embedding_weight,
+                                      role_id,
+                                      decoder_role_id,
+                                      role_embedding_table,
+                                      position_ids,
+                                      decoder_position_ids,
                                       output_ids,
                                       parent_ids,
                                       sequence_length,
+                                      output_scores,
                                       decoding_strategy,
                                       beam_size,
                                       topk,
@@ -171,17 +190,20 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
                                       normalize_before,
                                       pos_bias,
                                       hidden_act,
-                                      early_stopping);
+                                      early_stopping,
+                                      min_length);
   } else {
     PD_THROW("Not implemented place. Only GPU is supported. ");
   }
 }
 
 std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
-    const std::vector<std::vector<int64_t>>& cache_k_shapes,
-    const std::vector<std::vector<int64_t>>& cache_v_shapes,
+    const std::vector<int64_t>& input_ids_shape,
+    const std::vector<int64_t>& attn_mask_shape,
     const std::vector<int64_t>& mem_seq_len_shape,
     const std::vector<int64_t>& logits_mask_shape,
+    const std::vector<int64_t>& type_id_shape,
+    const std::vector<int64_t>& decoder_type_id_shape,
     const std::vector<int64_t>& word_embedding_shape,
     const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
     const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
@@ -209,6 +231,11 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     const std::vector<int64_t>& embedding_bias_shape,
     const std::vector<int64_t>& positional_embedding_weight_shape,
     const std::vector<int64_t>& type_embedding_weight_shape,
+    const std::vector<int64_t>& role_id_shape,
+    const std::vector<int64_t>& decoder_role_id_shape,
+    const std::vector<int64_t>& role_embedding_table_shape,
+    const std::vector<int64_t>& position_ids_shape,
+    const std::vector<int64_t>& decoder_position_ids_shape,
     const std::string& decoding_strategy,
     const int& beam_size,
     const int& topk,
@@ -228,17 +255,20 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     const bool& pos_bias,
     const std::string& hidden_act,
     const bool& rel_len,
-    const bool& early_stopping) {
-  int batch_size = cache_k_shapes[0][0];
+    const bool& early_stopping,
+    const int& min_length) {
+  int batch_size = input_ids_shape[0];
 
-  std::vector<int64_t> output_dims;
+  std::vector<int64_t> output_ids_dims;
+  std::vector<int64_t> output_scores_dims;
   std::vector<int64_t> sequence_length_dims({batch_size});
   if (decoding_strategy == "beam_search") {
     if (batch_size != -1) {
       batch_size /= beam_size;
     }
-    output_dims = {max_len, batch_size, beam_size};
-    return {output_dims, output_dims, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size, beam_size};
+    output_scores_dims = {batch_size, beam_size};
+    return {output_ids_dims, output_ids_dims, sequence_length_dims, output_scores_dims};
   } else if (decoding_strategy == "beam_search_v2" ||
              decoding_strategy == "beam_search_v3") {
     // Use separated alive and finish beam queues to avoid the decrease of alive
@@ -250,23 +280,27 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     } else {
       sequence_length_dims = {batch_size};
     }
-    output_dims = {max_len, batch_size, beam_size * 2};
-    return {output_dims, output_dims, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size, beam_size * 2};
+    output_scores_dims = {batch_size, beam_size * 2};
+    return {output_ids_dims, output_ids_dims, sequence_length_dims, output_scores_dims};
   } else if (decoding_strategy == "topk_sampling" ||
              decoding_strategy == "topp_sampling" ||
             decoding_strategy == "sampling") {
-    output_dims = {max_len, batch_size};
-    return {output_dims, {1}, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size};
+    output_scores_dims = {batch_size};
+    return {output_ids_dims, {1}, sequence_length_dims, output_scores_dims};
   } else {
     PD_THROW("Not supported decoding strategy. ");
   }
 }
 
 std::vector<paddle::DataType> UnifiedDecodingInferDtype(
-    const std::vector<paddle::DataType>& cache_k,
-    const std::vector<paddle::DataType>& cache_v,
+    const paddle::DataType& input_ids,
+    const paddle::DataType& attn_mask,
     const paddle::DataType& mem_seq_len,
     const paddle::DataType& logits_mask,
+    const paddle::DataType& type_id,
+    const paddle::DataType& decoder_type_id,
     const paddle::DataType& word_embedding,
     const std::vector<paddle::DataType>& self_ln_weight,
     const std::vector<paddle::DataType>& self_ln_bias,
@@ -293,17 +327,24 @@ std::vector<paddle::DataType> UnifiedDecodingInferDtype(
     const paddle::DataType& embedding_weight,
     const paddle::DataType& embedding_bias,
     const paddle::DataType& positional_embedding_weight,
-    const paddle::DataType& type_embedding_weight) {
+    const paddle::DataType& type_embedding_weight,
+    const paddle::DataType& role_id,
+    const paddle::DataType& decoder_role_id,
+    const paddle::DataType& role_embedding_table,
+    const paddle::DataType& position_ids,
+    const paddle::DataType& decoder_position_ids) {
   return {paddle::DataType::INT32,
           paddle::DataType::INT32,
-          paddle::DataType::INT32};
+          paddle::DataType::INT32,
+          paddle::DataType::FLOAT32};
 }
 
 PD_BUILD_OP(fusion_unified_decoding)
-    .Inputs({paddle::Vec("CacheK"),
-             paddle::Vec("CacheV"),
+    .Inputs({"InputIds",
+             "AttnMask",
              "MemSeqLen",
-             "TypeId",
+             "TypeIds",
+             "DecTypeIds",
              "LogitsMask",
              "WordEmbedding",
              paddle::Vec("SelfLayernormWeight"),
@@ -331,8 +372,13 @@ PD_BUILD_OP(fusion_unified_decoding)
             "EmbWeight",
             "EmbBias",
             "PositionEncEmb",
-            "TypeEmb"})
-    .Outputs({"OutputIds", "ParentIds", "SequenceLength"})
+            "TypeEmb",
+            "RoleIds",
+            "DecRoleIds",
+            "RoleEmbedding",
+            "PositionIds",
+            "DecPositionIds"})
+    .Outputs({"OutputIds", "ParentIds", "SequenceLength", "OutputScores"})
     .Attrs({"decoding_strategy: std::string",
             "beam_size: int",
             "topk: int",
@@ -352,7 +398,8 @@ PD_BUILD_OP(fusion_unified_decoding)
             "pos_bias: bool",
             "hidden_act: std::string",
             "rel_len: bool",
-            "early_stopping: bool"})
+            "early_stopping: bool",
+            "min_length: int"})
     .SetKernelFn(PD_KERNEL(UnifiedDecodingForward))
     .SetInferShapeFn(PD_INFER_SHAPE(UnifiedDecodingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(UnifiedDecodingInferDtype));
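
The new fourth output, OutputScores, is FLOAT32 (see UnifiedDecodingInferDtype above), and its shape depends on the decoding strategy. Below is a minimal standalone sketch of that shape rule, distilled from the InferShape branches in this diff; batch_size here is the value after the op has divided the incoming batch by beam_size for the beam-search strategies, and the helper name and the exception are illustrative (the op itself uses PD_THROW):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<int64_t> OutputScoresDims(const std::string& decoding_strategy,
                                      int64_t batch_size,
                                      int64_t beam_size) {
  if (decoding_strategy == "beam_search") {
    return {batch_size, beam_size};       // one score per beam
  } else if (decoding_strategy == "beam_search_v2" ||
             decoding_strategy == "beam_search_v3") {
    return {batch_size, beam_size * 2};   // alive + finished beam queues
  } else if (decoding_strategy == "topk_sampling" ||
             decoding_strategy == "topp_sampling" ||
             decoding_strategy == "sampling") {
    return {batch_size};                  // one score per sequence
  }
  throw std::invalid_argument("Not supported decoding strategy.");
}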
