@@ -19,10 +19,11 @@ limitations under the License. */
 
 
 std::vector<paddle::Tensor> UnifiedDecodingForward(
-    const std::vector<paddle::Tensor>& cache_k,
-    const std::vector<paddle::Tensor>& cache_v,
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& attn_mask,
     const paddle::Tensor& mem_seq_len,
     const paddle::Tensor& type_id,
+    const paddle::Tensor& decoder_type_id,
     const paddle::Tensor& logits_mask,
     const paddle::Tensor& word_embedding,
     const std::vector<paddle::Tensor>& self_ln_weight,
@@ -51,6 +52,11 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     const paddle::Tensor& embedding_bias,
     const paddle::Tensor& positional_embedding_weight,
     const paddle::Tensor& type_embedding_weight,
+    const paddle::Tensor& role_id,
+    const paddle::Tensor& decoder_role_id,
+    const paddle::Tensor& role_embedding_table,
+    const paddle::Tensor& position_ids,
+    const paddle::Tensor& decoder_position_ids,
     const std::string& decoding_strategy,
     const int& beam_size,
     const int& topk,
@@ -70,19 +76,22 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     const bool& pos_bias,
     const std::string& hidden_act,
     const bool& rel_len,
-    const bool& early_stopping) {
-  int batch_size = cache_k[0].shape()[0];
-  int max_out_len = rel_len ? max_len + cache_k[0].shape()[2] : max_len;
+    const bool& early_stopping,
+    const int& min_length) {
+  int batch_size = input_ids.shape()[0];
+  int max_out_len = rel_len ? max_len + input_ids.shape()[1] : max_len;
 
-  std::vector<int64_t> output_dims;
+  std::vector<int64_t> output_ids_dims;
+  std::vector<int64_t> output_scores_dims;
   std::vector<int64_t> parent_ids_dims;
   std::vector<int64_t> sequence_length_dims({batch_size});
   if (decoding_strategy == "beam_search") {
     if (batch_size != -1) {
       batch_size /= beam_size;
     }
-    output_dims = {max_out_len, batch_size, beam_size};
-    parent_ids_dims = output_dims;
+    output_ids_dims = {max_out_len, batch_size, beam_size};
+    output_scores_dims = {batch_size, beam_size};
+    parent_ids_dims = output_ids_dims;
   } else if (decoding_strategy == "beam_search_v2" ||
              decoding_strategy == "beam_search_v3") {
     // Use separated alive and finish beam queues to avoid the decrease of alive
@@ -94,22 +103,25 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
     } else {
       sequence_length_dims = {batch_size};
     }
-    output_dims = {max_out_len, batch_size, beam_size * 2};
-    parent_ids_dims = output_dims;
+    output_ids_dims = {max_out_len, batch_size, beam_size * 2};
+    output_scores_dims = {batch_size, beam_size * 2};
+    parent_ids_dims = output_ids_dims;
   } else if (decoding_strategy == "topk_sampling" ||
              decoding_strategy == "topp_sampling" ||
              decoding_strategy == "sampling") {
-    output_dims = {max_out_len, batch_size};
+    output_ids_dims = {max_out_len, batch_size};
+    output_scores_dims = {batch_size};
     parent_ids_dims = {1};
   } else {
     PD_THROW("Not supported decoding strategy.");
   }
-  auto output_ids = paddle::Tensor(cache_k[0].place(), output_dims);
-  auto parent_ids = paddle::Tensor(cache_k[0].place(), parent_ids_dims);
+  auto output_ids = paddle::Tensor(input_ids.place(), output_ids_dims);
+  auto parent_ids = paddle::Tensor(input_ids.place(), parent_ids_dims);
   auto sequence_length =
-      paddle::Tensor(cache_k[0].place(), sequence_length_dims);
+      paddle::Tensor(input_ids.place(), sequence_length_dims);
+  auto output_scores = paddle::Tensor(input_ids.place(), output_scores_dims);
 
-  if (cache_k[0].place() == paddle::PlaceType::kGPU) {
+  if (input_ids.place() == paddle::PlaceType::kGPU) {
     auto mem_seq_length = paddle::Tensor(paddle::PlaceType::kGPU);
 
     if (mem_seq_len.place() != paddle::PlaceType::kGPU) {
@@ -118,10 +130,11 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
       mem_seq_length = mem_seq_len;
     }
 
-    return UnifiedDecodingCUDAForward(cache_k,
-                                      cache_v,
+    return UnifiedDecodingCUDAForward(input_ids,
+                                      attn_mask,
                                       mem_seq_length,
                                       type_id,
+                                      decoder_type_id,
                                       logits_mask,
                                       word_embedding,
                                       self_ln_weight,
@@ -150,9 +163,15 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
                                       embedding_bias,
                                       positional_embedding_weight,
                                       type_embedding_weight,
+                                      role_id,
+                                      decoder_role_id,
+                                      role_embedding_table,
+                                      position_ids,
+                                      decoder_position_ids,
                                       output_ids,
                                       parent_ids,
                                       sequence_length,
+                                      output_scores,
                                       decoding_strategy,
                                       beam_size,
                                       topk,
@@ -171,17 +190,20 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
                                       normalize_before,
                                       pos_bias,
                                       hidden_act,
-                                      early_stopping);
+                                      early_stopping,
+                                      min_length);
   } else {
     PD_THROW("Not implemented place. Only GPU is supported.");
   }
 }
 
 std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
-    const std::vector<std::vector<int64_t>>& cache_k_shapes,
-    const std::vector<std::vector<int64_t>>& cache_v_shapes,
+    const std::vector<int64_t>& input_ids_shape,
+    const std::vector<int64_t>& attn_mask_shape,
     const std::vector<int64_t>& mem_seq_len_shape,
     const std::vector<int64_t>& logits_mask_shape,
+    const std::vector<int64_t>& type_id_shape,
+    const std::vector<int64_t>& decoder_type_id_shape,
     const std::vector<int64_t>& word_embedding_shape,
     const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
     const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
@@ -209,6 +231,11 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     const std::vector<int64_t>& embedding_bias_shape,
     const std::vector<int64_t>& positional_embedding_weight_shape,
     const std::vector<int64_t>& type_embedding_weight_shape,
+    const std::vector<int64_t>& role_id_shape,
+    const std::vector<int64_t>& decoder_role_id_shape,
+    const std::vector<int64_t>& role_embedding_table_shape,
+    const std::vector<int64_t>& position_ids_shape,
+    const std::vector<int64_t>& decoder_position_ids_shape,
     const std::string& decoding_strategy,
     const int& beam_size,
     const int& topk,
@@ -228,17 +255,20 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     const bool& pos_bias,
     const std::string& hidden_act,
     const bool& rel_len,
-    const bool& early_stopping) {
-  int batch_size = cache_k_shapes[0][0];
+    const bool& early_stopping,
+    const int& min_length) {
+  int batch_size = input_ids_shape[0];
 
-  std::vector<int64_t> output_dims;
+  std::vector<int64_t> output_ids_dims;
+  std::vector<int64_t> output_scores_dims;
   std::vector<int64_t> sequence_length_dims({batch_size});
   if (decoding_strategy == "beam_search") {
     if (batch_size != -1) {
       batch_size /= beam_size;
     }
-    output_dims = {max_len, batch_size, beam_size};
-    return {output_dims, output_dims, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size, beam_size};
+    output_scores_dims = {batch_size, beam_size};
+    return {output_ids_dims, output_ids_dims, sequence_length_dims, output_scores_dims};
   } else if (decoding_strategy == "beam_search_v2" ||
              decoding_strategy == "beam_search_v3") {
     // Use separated alive and finish beam queues to avoid the decrease of alive
@@ -250,23 +280,27 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
     } else {
       sequence_length_dims = {batch_size};
     }
-    output_dims = {max_len, batch_size, beam_size * 2};
-    return {output_dims, output_dims, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size, beam_size * 2};
+    output_scores_dims = {batch_size, beam_size * 2};
+    return {output_ids_dims, output_ids_dims, sequence_length_dims, output_scores_dims};
   } else if (decoding_strategy == "topk_sampling" ||
              decoding_strategy == "topp_sampling" ||
              decoding_strategy == "sampling") {
-    output_dims = {max_len, batch_size};
-    return {output_dims, {1}, sequence_length_dims};
+    output_ids_dims = {max_len, batch_size};
+    output_scores_dims = {batch_size};
+    return {output_ids_dims, {1}, sequence_length_dims, output_scores_dims};
   } else {
     PD_THROW("Not supported decoding strategy.");
   }
 }
 
 std::vector<paddle::DataType> UnifiedDecodingInferDtype(
-    const std::vector<paddle::DataType>& cache_k,
-    const std::vector<paddle::DataType>& cache_v,
+    const paddle::DataType& input_ids,
+    const paddle::DataType& attn_mask,
     const paddle::DataType& mem_seq_len,
     const paddle::DataType& logits_mask,
+    const paddle::DataType& type_id,
+    const paddle::DataType& decoder_type_id,
     const paddle::DataType& word_embedding,
     const std::vector<paddle::DataType>& self_ln_weight,
     const std::vector<paddle::DataType>& self_ln_bias,
@@ -293,17 +327,24 @@ std::vector<paddle::DataType> UnifiedDecodingInferDtype(
     const paddle::DataType& embedding_weight,
     const paddle::DataType& embedding_bias,
     const paddle::DataType& positional_embedding_weight,
-    const paddle::DataType& type_embedding_weight) {
+    const paddle::DataType& type_embedding_weight,
+    const paddle::DataType& role_id,
+    const paddle::DataType& decoder_role_id,
+    const paddle::DataType& role_embedding_table,
+    const paddle::DataType& position_ids,
+    const paddle::DataType& decoder_position_ids) {
   return {paddle::DataType::INT32,
           paddle::DataType::INT32,
-          paddle::DataType::INT32};
+          paddle::DataType::INT32,
+          paddle::DataType::FLOAT32};
 }
 
 PD_BUILD_OP(fusion_unified_decoding)
-    .Inputs({paddle::Vec("CacheK"),
-             paddle::Vec("CacheV"),
+    .Inputs({"InputIds",
+             "AttnMask",
              "MemSeqLen",
-             "TypeId",
+             "TypeIds",
+             "DecTypeIds",
              "LogitsMask",
              "WordEmbedding",
              paddle::Vec("SelfLayernormWeight"),
@@ -331,8 +372,13 @@ PD_BUILD_OP(fusion_unified_decoding)
              "EmbWeight",
              "EmbBias",
              "PositionEncEmb",
-             "TypeEmb"})
-    .Outputs({"OutputIds", "ParentIds", "SequenceLength"})
+             "TypeEmb",
+             "RoleIds",
+             "DecRoleIds",
+             "RoleEmbedding",
+             "PositionIds",
+             "DecPositionIds"})
+    .Outputs({"OutputIds", "ParentIds", "SequenceLength", "OutputScores"})
     .Attrs({"decoding_strategy: std::string",
             "beam_size: int",
             "topk: int",
@@ -352,7 +398,8 @@ PD_BUILD_OP(fusion_unified_decoding)
             "pos_bias: bool",
             "hidden_act: std::string",
             "rel_len: bool",
-            "early_stopping: bool"})
+            "early_stopping: bool",
+            "min_length: int"})
     .SetKernelFn(PD_KERNEL(UnifiedDecodingForward))
     .SetInferShapeFn(PD_INFER_SHAPE(UnifiedDecodingInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(UnifiedDecodingInferDtype));
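For reference, below is a minimal standalone sketch of the per-strategy output-shape bookkeeping this diff introduces, where each decoding strategy now sizes an extra output_scores tensor alongside output_ids. The DecodingShapes struct and InferDecodingShapes function are names invented for this sketch, not part of the PaddleNLP source; the real op applies the same rules when allocating the paddle::Tensor outputs and again in UnifiedDecodingInferShape.

// Sketch only: mirrors the shape-selection rules in the diff above, with
// invented names. Compile as ordinary C++; no Paddle dependency.
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct DecodingShapes {
  std::vector<int64_t> output_ids;
  std::vector<int64_t> output_scores;
  std::vector<int64_t> parent_ids;
  std::vector<int64_t> sequence_length;
};

DecodingShapes InferDecodingShapes(const std::string& strategy,
                                   int64_t batch_size,  // tiled by beam_size
                                   int64_t beam_size,
                                   int64_t max_out_len) {
  DecodingShapes s;
  // By default sequence_length keeps the incoming (beam-tiled) batch size.
  s.sequence_length = {batch_size};
  if (strategy == "beam_search") {
    // -1 marks a dynamic batch dimension, so only divide a known size.
    if (batch_size != -1) batch_size /= beam_size;
    s.output_ids = {max_out_len, batch_size, beam_size};
    s.output_scores = {batch_size, beam_size};
    s.parent_ids = s.output_ids;
  } else if (strategy == "beam_search_v2" || strategy == "beam_search_v3") {
    // Alive and finished beams are kept in separate queues, so the outputs
    // hold 2 * beam_size candidates and sequence_length doubles to match.
    if (batch_size != -1) {
      s.sequence_length = {batch_size * 2};
      batch_size /= beam_size;
    }
    s.output_ids = {max_out_len, batch_size, beam_size * 2};
    s.output_scores = {batch_size, beam_size * 2};
    s.parent_ids = s.output_ids;
  } else if (strategy == "topk_sampling" || strategy == "topp_sampling" ||
             strategy == "sampling") {
    // Sampling keeps one sequence per example; parent_ids is a placeholder.
    s.output_ids = {max_out_len, batch_size};
    s.output_scores = {batch_size};
    s.parent_ids = {1};
  } else {
    throw std::invalid_argument("Not supported decoding strategy.");
  }
  return s;
}

Under these assumptions, InferDecodingShapes("beam_search", 8, 4, 32) yields output_ids of shape {32, 2, 4} and output_scores of shape {2, 4}, matching the OutputIds and OutputScores shapes the op above would register for that configuration.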