@@ -1108,6 +1108,9 @@ class InferTransformerModel(TransformerModel):
             `[batch_size, seq_len, beam_size]`. If `True`, the data layout would
             be time major with shape `[seq_len, batch_size, beam_size]`. Default
             to `False`.
+        beam_search_version (str): Specify the beam search version. It should be
+            one of [`v1`, `v2`]. If `v2`, `alpha` (default 0.6) can be passed to
+            control the length penalty. Default to `v1`.
     """

     def __init__(self,
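
With the flag in place, the decoding strategy becomes a constructor-level choice, and `alpha` travels through `**kwargs` so existing call sites are untouched. A minimal construction sketch: only `beam_size`, `max_out_len`, `output_time_major`, `beam_search_version`, and `alpha` appear in this diff, so the remaining argument names and values are assumptions based on the surrounding PaddleNLP signature.

```python
from paddlenlp.transformers import InferTransformerModel

# Only beam_search_version and alpha are introduced by this patch; the
# other arguments and their values are illustrative assumptions.
transformer = InferTransformerModel(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    beam_size=4,
    max_out_len=256,
    beam_search_version='v2',  # defaults to 'v1'
    alpha=0.6)                 # read from **kwargs only when version is 'v2'
```
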
@@ -1127,14 +1130,23 @@ def __init__(self,
                  eos_id=1,
                  beam_size=4,
                  max_out_len=256,
-                 output_time_major=False):
+                 output_time_major=False,
+                 beam_search_version='v1',
+                 **kwargs):
         args = dict(locals())
         args.pop("self")
         args.pop("__class__", None)
         self.beam_size = args.pop("beam_size")
         self.max_out_len = args.pop("max_out_len")
         self.output_time_major = args.pop("output_time_major")
         self.dropout = dropout
+        self.beam_search_version = args.pop('beam_search_version')
+        kwargs = args.pop("kwargs")
+        if self.beam_search_version == 'v2':
+            if 'alpha' in kwargs:
+                self.alpha = kwargs['alpha']
+            else:
+                self.alpha = 0.6
         super(InferTransformerModel, self).__init__(**args)

         cell = TransformerDecodeCell(
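
The `dict(locals())` snapshot above is what lets the subclass strip its inference-only arguments before delegating to `TransformerModel.__init__`, and `kwargs` is popped as a plain dict, which is why `alpha` never has to appear in either signature. A standalone sketch of the same pattern with toy classes (`Base` and `Derived` are illustrative, not the real ones):

```python
class Base:
    def __init__(self, d_model, dropout):
        self.d_model = d_model
        self.dropout = dropout


class Derived(Base):
    def __init__(self, d_model=512, dropout=0.1, beam_search_version='v1',
                 **kwargs):
        # Snapshot the arguments, then strip everything Base does not accept.
        args = dict(locals())
        args.pop("self")
        args.pop("__class__", None)  # injected because super() is used below
        self.beam_search_version = args.pop("beam_search_version")
        kwargs = args.pop("kwargs")
        if self.beam_search_version == 'v2':
            # Same effect as the patch's if/else fallback to 0.6.
            self.alpha = kwargs.get('alpha', 0.6)
        super().__init__(**args)


d = Derived(beam_search_version='v2', alpha=0.8)
assert d.alpha == 0.8 and d.d_model == 512
```
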
@@ -1191,48 +1203,59 @@ def forward(self, src_word, trg_word=None):
                 transformer(
                     src_word=paddle.randint(low=3, high=30000, shape=[batch_size, seq_len]))
         """
-        src_max_len = paddle.shape(src_word)[-1]
-        src_slf_attn_bias = paddle.cast(
-            src_word == self.bos_id,
-            dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
-        trg_src_attn_bias = src_slf_attn_bias
-        src_pos = paddle.cast(
-            src_word != self.bos_id, dtype="int64") * paddle.arange(
-                start=0, end=src_max_len)
+        if self.beam_search_version == 'v1':
+            src_max_len = paddle.shape(src_word)[-1]
+            src_slf_attn_bias = paddle.cast(
+                src_word == self.bos_id,
+                dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
+            trg_src_attn_bias = src_slf_attn_bias
+            src_pos = paddle.cast(
+                src_word != self.bos_id, dtype="int64") * paddle.arange(
+                    start=0, end=src_max_len)
+
+            # Run encoder
+            src_emb = self.src_word_embedding(src_word)
+            src_pos_emb = self.src_pos_embedding(src_pos)
+            src_emb = src_emb + src_pos_emb
+            enc_input = F.dropout(
+                src_emb, p=self.dropout,
+                training=False) if self.dropout else src_emb
+            enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)

-        # Run encoder
-        src_emb = self.src_word_embedding(src_word)
-        src_pos_emb = self.src_pos_embedding(src_pos)
-        src_emb = src_emb + src_pos_emb
-        enc_input = F.dropout(
-            src_emb, p=self.dropout,
-            training=False) if self.dropout else src_emb
-        enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
+            # Init states (caches) for transformer, need to be updated according to selected beam
+            incremental_cache, static_cache = self.transformer.decoder.gen_cache(
+                enc_output, do_zip=True)

-        # Init states (caches) for transformer, need to be updated according to selected beam
-        incremental_cache, static_cache = self.transformer.decoder.gen_cache(
-            enc_output, do_zip=True)
+            static_cache, enc_output, trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+                (static_cache, enc_output, trg_src_attn_bias), self.beam_size)

-        static_cache, enc_output, trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
-            (static_cache, enc_output, trg_src_attn_bias), self.beam_size)
+            if trg_word is not None:
+                trg_length = paddle.sum(paddle.cast(
+                    trg_word != self.bos_id, dtype="int64"),
+                    axis=-1)
+            else:
+                trg_length = None
+
+            rs, _ = nn.decode.dynamic_decode(
+                decoder=self.decode,
+                inits=incremental_cache,
+                max_step_num=self.max_out_len,
+                memory=enc_output,
+                trg_src_attn_bias=trg_src_attn_bias,
+                static_cache=static_cache,
+                is_test=True,
+                output_time_major=self.output_time_major,
+                trg_word=trg_word,
+                trg_length=trg_length)
+
+            return rs
+
+        elif self.beam_search_version == 'v2':
+            finished_seq, finished_scores = self.beam_search_v2(
+                src_word, self.beam_size, self.max_out_len, self.alpha)
+            if self.output_time_major:
+                finished_seq = finished_seq.transpose([2, 0, 1])
+            else:
+                finished_seq = finished_seq.transpose([0, 2, 1])

-        if trg_word is not None:
-            trg_length = paddle.sum(paddle.cast(
-                trg_word != self.bos_id, dtype="int64"),
-                axis=-1)
-        else:
-            trg_length = None
-
-        rs, _ = nn.decode.dynamic_decode(
-            decoder=self.decode,
-            inits=incremental_cache,
-            max_step_num=self.max_out_len,
-            memory=enc_output,
-            trg_src_attn_bias=trg_src_attn_bias,
-            static_cache=static_cache,
-            is_test=True,
-            output_time_major=self.output_time_major,
-            trg_word=trg_word,
-            trg_length=trg_length)
-
-        return rs
+            return finished_seq
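
In the `v2` branch, `beam_search_v2` returns `finished_seq` laid out as `[batch_size, beam_size, seq_len]`; the transposes above map it to the documented `[batch_size, seq_len, beam_size]` (batch major) or `[seq_len, batch_size, beam_size]` (time major) layouts. The `alpha` argument feeds the length penalty used to rank finished hypotheses. Below is a sketch of the GNMT-style penalty (Wu et al., 2016) used by tensor2tensor-lineage beam search; that `beam_search_v2` follows this exact formula is an assumption here.

```python
def length_penalty(length, alpha=0.6):
    """GNMT-style length penalty: ((5 + length) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha

# Finished hypotheses are ranked by log-probability normalized by this
# penalty: alpha = 0 disables normalization, while larger alpha shifts
# the ranking toward longer candidate translations.
log_prob = -8.0
score = log_prob / length_penalty(12, alpha=0.6)
```
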