
Commit 85ec79b

Author: gongel
refactor: combine beam search v1 and v2
1 parent 8c8fc03 commit 85ec79b

3 files changed: +72 -45 lines


examples/machine_translation/transformer/predict.py

Lines changed: 3 additions & 1 deletion
@@ -84,7 +84,9 @@ def do_predict(args):
         eos_id=args.eos_idx,
         beam_size=args.beam_size,
         max_out_len=args.max_out_len,
-        use_ft=not args.without_ft)
+        use_ft=not args.without_ft,
+        beam_search_version='v2',
+        alpha=0.6)
 
     # Load the trained model
     assert args.init_from_params, (
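With this hunk the translation example opts into beam search v2 and a length-penalty strength of 0.6. The sketch below collects just the keyword arguments touched by this commit; `args` is the script's parsed configuration, the remaining constructor arguments are unchanged and omitted, and the surrounding call is assumed to construct PaddleNLP's `TransformerGenerator` as in the rest of this example.

```python
# Decoding options threaded through the generator construction in predict.py.
# Pass beam_search_version='v1' to keep the pre-refactor decoder path; alpha is
# only consumed when v2 is selected.
decoding_kwargs = dict(
    eos_id=args.eos_idx,
    beam_size=args.beam_size,
    max_out_len=args.max_out_len,
    use_ft=not args.without_ft,
    beam_search_version='v2',
    alpha=0.6)
```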

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 4 additions & 2 deletions
@@ -344,7 +344,8 @@ def __init__(self,
                 eos_id=eos_id,
                 beam_size=beam_size,
                 max_out_len=max_out_len,
-                output_time_major=self.output_time_major)
+                output_time_major=self.output_time_major,
+                **kwargs)
         else:
             self.transformer = InferTransformerModel(
                 src_vocab_size=src_vocab_size,
@@ -361,7 +362,8 @@ def __init__(self,
                 eos_id=eos_id,
                 beam_size=beam_size,
                 max_out_len=max_out_len,
-                output_time_major=self.output_time_major)
+                output_time_major=self.output_time_major,
+                **kwargs)
 
     def forward(self, src_word):
         r"""

paddlenlp/transformers/transformer/modeling.py

Lines changed: 65 additions & 42 deletions
@@ -1108,6 +1108,9 @@ class InferTransformerModel(TransformerModel):
             `[batch_size, seq_len, beam_size]`. If `True`, the data layout would
             be time major with shape `[seq_len, batch_size, beam_size]`. Default
             to `False`.
+        beam_search_version (str): Specify the beam search version. It should be
+            one of [`v1`, `v2`]. If `v2`, `alpha` (default to 0.6) needs to be set
+            for the length penalty. Default to `v1`.
     """
 
     def __init__(self,
@@ -1127,14 +1130,23 @@ def __init__(self,
                  eos_id=1,
                  beam_size=4,
                  max_out_len=256,
-                 output_time_major=False):
+                 output_time_major=False,
+                 beam_search_version='v1',
+                 **kwargs):
         args = dict(locals())
         args.pop("self")
         args.pop("__class__", None)
         self.beam_size = args.pop("beam_size")
         self.max_out_len = args.pop("max_out_len")
         self.output_time_major = args.pop("output_time_major")
         self.dropout = dropout
+        self.beam_search_version = args.pop('beam_search_version')
+        kwargs = args.pop("kwargs")
+        if self.beam_search_version == 'v2':
+            if 'alpha' in kwargs:
+                self.alpha = kwargs['alpha']
+            else:
+                self.alpha = 0.6
         super(InferTransformerModel, self).__init__(**args)
 
         cell = TransformerDecodeCell(
@@ -1191,48 +1203,59 @@ def forward(self, src_word, trg_word=None):
                 transformer(
                     src_word=paddle.randint(low=3, high=30000, shape=[batch_size, seq_len]))
         """
-        src_max_len = paddle.shape(src_word)[-1]
-        src_slf_attn_bias = paddle.cast(
-            src_word == self.bos_id,
-            dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
-        trg_src_attn_bias = src_slf_attn_bias
-        src_pos = paddle.cast(
-            src_word != self.bos_id, dtype="int64") * paddle.arange(
-                start=0, end=src_max_len)
+        if self.beam_search_version == 'v1':
+            src_max_len = paddle.shape(src_word)[-1]
+            src_slf_attn_bias = paddle.cast(
+                src_word == self.bos_id,
+                dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
+            trg_src_attn_bias = src_slf_attn_bias
+            src_pos = paddle.cast(
+                src_word != self.bos_id, dtype="int64") * paddle.arange(
+                    start=0, end=src_max_len)
+
+            # Run encoder
+            src_emb = self.src_word_embedding(src_word)
+            src_pos_emb = self.src_pos_embedding(src_pos)
+            src_emb = src_emb + src_pos_emb
+            enc_input = F.dropout(
+                src_emb, p=self.dropout,
+                training=False) if self.dropout else src_emb
+            enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
 
-        # Run encoder
-        src_emb = self.src_word_embedding(src_word)
-        src_pos_emb = self.src_pos_embedding(src_pos)
-        src_emb = src_emb + src_pos_emb
-        enc_input = F.dropout(
-            src_emb, p=self.dropout,
-            training=False) if self.dropout else src_emb
-        enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
+            # Init states (caches) for transformer, need to be updated according to selected beam
+            incremental_cache, static_cache = self.transformer.decoder.gen_cache(
+                enc_output, do_zip=True)
 
-        # Init states (caches) for transformer, need to be updated according to selected beam
-        incremental_cache, static_cache = self.transformer.decoder.gen_cache(
-            enc_output, do_zip=True)
+            static_cache, enc_output, trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+                (static_cache, enc_output, trg_src_attn_bias), self.beam_size)
 
-        static_cache, enc_output, trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
-            (static_cache, enc_output, trg_src_attn_bias), self.beam_size)
+            if trg_word is not None:
+                trg_length = paddle.sum(paddle.cast(
+                    trg_word != self.bos_id, dtype="int64"),
+                                        axis=-1)
+            else:
+                trg_length = None
+
+            rs, _ = nn.decode.dynamic_decode(
+                decoder=self.decode,
+                inits=incremental_cache,
+                max_step_num=self.max_out_len,
+                memory=enc_output,
+                trg_src_attn_bias=trg_src_attn_bias,
+                static_cache=static_cache,
+                is_test=True,
+                output_time_major=self.output_time_major,
+                trg_word=trg_word,
+                trg_length=trg_length)
+
+            return rs
+
+        elif self.beam_search_version == 'v2':
+            finished_seq, finished_scores = self.beam_search_v2(
+                src_word, self.beam_size, self.max_out_len, self.alpha)
+            if self.output_time_major:
+                finished_seq = finished_seq.transpose([2, 0, 1])
+            else:
+                finished_seq = finished_seq.transpose([0, 2, 1])
 
-        if trg_word is not None:
-            trg_length = paddle.sum(paddle.cast(
-                trg_word != self.bos_id, dtype="int64"),
-                                    axis=-1)
-        else:
-            trg_length = None
-
-        rs, _ = nn.decode.dynamic_decode(
-            decoder=self.decode,
-            inits=incremental_cache,
-            max_step_num=self.max_out_len,
-            memory=enc_output,
-            trg_src_attn_bias=trg_src_attn_bias,
-            static_cache=static_cache,
-            is_test=True,
-            output_time_major=self.output_time_major,
-            trg_word=trg_word,
-            trg_length=trg_length)
-
-        return rs
+            return finished_seq
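In the `v2` branch the search returns the finished hypotheses directly, and the final transpose maps them into the same `[batch_size, seq_len, beam_size]` layout (or its time-major variant) that the docstring describes for `v1`. The `beam_search_v2` routine itself is not part of this diff; `alpha` plays the usual role of a length-penalty strength, and a common formulation (the GNMT penalty) is sketched below for reference only — the exact formula used by the implementation may differ.

```python
def length_penalty(length, alpha=0.6):
    # GNMT-style length penalty; alpha=0 disables length normalization.
    return ((5.0 + length) / 6.0) ** alpha

def normalized_score(sum_logprob, length, alpha=0.6):
    # Beam candidates are ranked by length-normalized log-probability,
    # which keeps longer hypotheses competitive with shorter ones.
    return sum_logprob / length_penalty(length, alpha)

# With the same total log-probability, the longer hypothesis scores higher.
print(normalized_score(-12.0, length=10))  # ~ -6.92
print(normalized_score(-12.0, length=20))  # ~ -5.10
```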
