
Commit 18b9ed6

author: gongel (committed)
feat: add beam_search_v2
1 parent edceb13

File tree: 2 files changed, +359 −0 lines changed
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import os
import yaml
import argparse
from pprint import pprint
from attrdict import AttrDict

import paddle
from paddlenlp.transformers import TransformerModel, position_encoding_init
import reader


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="./configs/transformer.base.yaml",
        type=str,
        help="Path of the config file. ")
    args = parser.parse_args()
    return args


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
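
# Example: with bos_idx=0 and eos_idx=1, post_process_seq([0, 5, 7, 1, 9], 0, 1)
# truncates at the first EOS and strips BOS/EOS, returning [5, 7].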


def do_predict(args):
    if args.device == "gpu":
        place = "gpu"
    else:
        place = "cpu"

    paddle.set_device(place)

    # Define data loader
    test_loader, to_tokens = reader.create_infer_loader(args)

    # Define model. `TransformerModel` is used directly here (rather than
    # `TransformerGenerator`, which automatically chooses between the
    # jit-built `FasterTransformer` and the slower `InferTransformerModel`)
    # so that the new `beam_search_v2` method can be called.
    transformer = TransformerModel(
        src_vocab_size=args.src_vocab_size,
        trg_vocab_size=args.trg_vocab_size,
        max_length=args.max_length + 1,
        num_encoder_layers=args.n_layer,
        num_decoder_layers=args.n_layer,
        n_head=args.n_head,
        d_model=args.d_model,
        d_inner_hid=args.d_inner_hid,
        dropout=args.dropout,
        weight_sharing=args.weight_sharing,
        bos_id=args.bos_idx,
        eos_id=args.eos_idx)

    # Load the trained model
    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")
    model_dict = paddle.load(
        os.path.join(args.init_from_params, "transformer.pdparams"))

    # To avoid a longer length than training, reset the size of position
    # encoding to max_length
    model_dict["src_pos_embedding.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)
    model_dict["trg_pos_embedding.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)

    # Load the model_dict
    transformer.load_dict(model_dict)

    # Set evaluate mode
    transformer.eval()

    f = open(args.output_file, "w", encoding="utf-8")

    with paddle.no_grad():
        for (src_word, ) in test_loader:
            # `beam_search_v2` returns `finished_seq` with shape
            # `[batch_size, beam_size, seq_len]`, best-scoring beams first.
            finished_seq, finished_scores = transformer.beam_search_v2(
                src_word=src_word,
                beam_size=args.beam_size,
                max_len=args.max_out_len,
                alpha=0.6)
            finished_seq = finished_seq.numpy()
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
                    word_list = to_tokens(id_list)
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)
    f.close()


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
    pprint(args)

    do_predict(args)
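
A minimal invocation sketch (the script file name `predict.py` is an assumption, as is the presence of `device`, `beam_size`, `max_out_len`, `n_best`, `output_file` and the model-size fields in the YAML config):

    python predict.py --config ./configs/transformer.base.yaml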

paddlenlp/transformers/transformer/modeling.py

Lines changed: 243 additions & 0 deletions
@@ -785,6 +785,249 @@ def forward(self, src_word, trg_word):
        return predict

    def beam_search_v2(self, src_word, beam_size=4, max_len=None, alpha=0.6):
        """
        Beam search maintaining two queues, the alive queue and the finished
        queue, each with a capacity of `beam_size`. It proceeds in three
        steps: `grow_topk`, `grow_alive` and `grow_finish`.

        1. `grow_topk` selects the top `2*beam_size` candidates to avoid all
           of them getting EOS.
        2. `grow_alive` selects the top `beam_size` non-EOS candidates as the
           inputs of the next decoding step.
        3. `grow_finish` compares the already finished candidates in the
           finished queue with the newly finished candidates from `grow_topk`,
           and keeps the top `beam_size` finished candidates.
        """

        def expand_to_beam_size(tensor, beam_size):
            tensor = paddle.reshape(tensor,
                                    [tensor.shape[0], 1] + tensor.shape[1:])
            tile_dims = [1] * len(tensor.shape)
            tile_dims[1] = beam_size
            return paddle.tile(tensor, tile_dims)

        def merge_beam_dim(tensor):
            return paddle.reshape(tensor, [-1] + tensor.shape[2:])
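        # Shape flow, e.g. for `enc_output`: [batch, src_len, d_model]
        # -> expand_to_beam_size -> [batch, beam, src_len, d_model]
        # -> merge_beam_dim -> [batch * beam, src_len, d_model]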

        # run encoder
        src_max_len = paddle.shape(src_word)[-1]
        src_slf_attn_bias = paddle.cast(
            src_word == self.bos_id,
            dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
        src_slf_attn_bias.stop_gradient = True
        src_pos = paddle.cast(
            src_word != self.bos_id, dtype="int64") * paddle.arange(
                start=0, end=src_max_len)
        src_emb = self.src_word_embedding(src_word)
        src_pos_emb = self.src_pos_embedding(src_pos)
        src_emb = src_emb + src_pos_emb
        enc_input = F.dropout(
            src_emb, p=self.dropout,
            training=self.training) if self.dropout else src_emb

        enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)

        # constant number
        inf = float(1. * 1e7)
        batch_size = enc_output.shape[0]
        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len

        ### initialize states of beam search ###
        ## init for the alive ##
        initial_log_probs = paddle.to_tensor(
            np.array(
                [[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
        alive_log_probs = paddle.tile(initial_log_probs, [batch_size, 1])
        alive_seq = paddle.to_tensor(
            np.tile(
                np.array(
                    [[[self.bos_id]]], dtype="int64"),
                (batch_size, beam_size, 1)))
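        # Only beam 0 starts at log prob 0.; the other beams start at -inf so
        # the first top-k selection draws `2 * beam_size` distinct candidates
        # from beam 0 instead of `beam_size` copies of the same BOS prefix.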

        ## init for the finished ##
        finished_scores = paddle.to_tensor(
            np.array(
                [[-inf] * beam_size], dtype="float32"))
        finished_scores = paddle.tile(finished_scores, [batch_size, 1])
        finished_seq = paddle.to_tensor(
            np.tile(
                np.array(
                    [[[self.bos_id]]], dtype="int64"),
                (batch_size, beam_size, 1)))
        finished_flags = paddle.zeros_like(finished_scores)

        ### initialize inputs and states of transformer decoder ###
        ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
        trg_word = paddle.reshape(alive_seq[:, :, -1],
                                  [batch_size * beam_size, 1])
        trg_src_attn_bias = src_slf_attn_bias
        trg_src_attn_bias = merge_beam_dim(
            expand_to_beam_size(trg_src_attn_bias, beam_size))
        enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))

        ## init states (caches) for transformer, to be updated according to
        ## the selected beam at each step
        caches = self.transformer.decoder.gen_cache(enc_output, do_zip=False)

        def update_states(caches, beam_idx, beam_size):
            new_caches = []
            for cache in caches:
                k = gather_2d_by_gather(cache[0].k, beam_idx, beam_size,
                                        batch_size, False)
                v = gather_2d_by_gather(cache[0].v, beam_idx, beam_size,
                                        batch_size, False)
                new_caches.append((nn.MultiHeadAttention.Cache(k, v), cache[1]))
            return new_caches

        def gather_2d_by_gather(tensor_nd,
                                beam_idx,
                                beam_size,
                                batch_size,
                                need_flat=True):
            batch_idx = paddle.arange(
                0, batch_size, 1, dtype="int64") * beam_size
            flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
            idx = paddle.reshape(
                paddle.add(beam_idx, batch_idx.unsqueeze(-1)), [-1])
            new_flat_tensor = paddle.gather(flat_tensor, idx)
            new_tensor_nd = paddle.reshape(
                new_flat_tensor,
                shape=[batch_size, beam_idx.shape[1]] +
                tensor_nd.shape[2:]) if need_flat else new_flat_tensor
            return new_tensor_nd
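        # E.g. batch_size=2, beam_size=2, beam_idx=[[1, 0], [0, 0]]:
        # batch_idx=[0, 2] and the flattened gather indices are [1, 0, 2, 2],
        # i.e. row `batch * beam_size + beam` of the merged tensor.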

        def early_finish(alive_log_probs, finished_scores,
                         finished_in_finished):
            max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
            # The best possible score of the most likely alive sequence
            lower_bound_alive_scores = alive_log_probs[:,
                                                       0] / max_length_penalty

            # Now compute the lowest score of a finished sequence in finished.
            # If the sequence isn't finished, we multiply its score by 0.;
            # since scores are all negative, taking the min will give us the
            # score of the lowest finished item.
            lowest_score_of_finished_in_finished = paddle.min(
                finished_scores * finished_in_finished, 1)
            # If none of the sequences have finished, then the min will be 0
            # and we have to replace it by -inf. The score of any sequence in
            # alive will be much higher than -inf, so the termination
            # condition will not be met.
            lowest_score_of_finished_in_finished += (
                1. - paddle.max(finished_in_finished, 1)) * -inf
            bound_is_met = paddle.all(
                paddle.greater_than(lowest_score_of_finished_in_finished,
                                    lower_bound_alive_scores))

            return bound_is_met
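        # The check is conservative: an alive hypothesis's log prob can only
        # decrease as it grows while its length penalty never exceeds
        # max_length_penalty, so lower_bound_alive_scores bounds the best
        # score it could still reach.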

        def grow_topk(i, logits, alive_seq, alive_log_probs, states):
            logits = paddle.reshape(logits, [batch_size, beam_size, -1])
            candidate_log_probs = paddle.log(F.softmax(logits, axis=2))
            log_probs = paddle.add(candidate_log_probs,
                                   alive_log_probs.unsqueeze(-1))

            # GNMT length penalty: ((5 + length) / 6) ** alpha
            length_penalty = np.power((5.0 + (i + 1.0)) / 6.0, alpha)
            curr_scores = log_probs / length_penalty
            flat_curr_scores = paddle.reshape(curr_scores, [batch_size, -1])

            topk_scores, topk_ids = paddle.topk(
                flat_curr_scores, k=beam_size * 2)

            # recover the unnormalized log probs of the selected candidates
            topk_log_probs = topk_scores * length_penalty

            topk_beam_index = topk_ids // self.trg_vocab_size
            topk_ids = topk_ids % self.trg_vocab_size

            # use gather as gather_nd, TODO: use gather_nd
            topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
                                           beam_size, batch_size)
            topk_seq = paddle.concat(
                [topk_seq, paddle.reshape(topk_ids, topk_ids.shape + [1])],
                axis=2)
            states = update_states(states, topk_beam_index, beam_size)
            eos = paddle.full(
                shape=topk_ids.shape, dtype="int64", fill_value=self.eos_id)
            topk_finished = paddle.cast(paddle.equal(topk_ids, eos), "float32")

            # topk_seq: [batch_size, 2*beam_size, i+1]
            # topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
            return topk_seq, topk_log_probs, topk_scores, topk_finished, states

        def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                       states):
            # push finished candidates out of the alive top-k
            curr_scores += curr_finished * -inf
            _, topk_indexes = paddle.topk(curr_scores, k=beam_size)
            alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
                                            beam_size * 2, batch_size)
            alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
                                                  beam_size * 2, batch_size)
            states = update_states(states, topk_indexes, beam_size * 2)

            return alive_seq, alive_log_probs, states

        def grow_finished(finished_seq, finished_scores, finished_flags,
                          curr_seq, curr_scores, curr_finished):
            # Pad finished_seq with one EOS so its length matches the newly
            # grown curr_seq before concatenation
            finished_seq = paddle.concat(
                [
                    finished_seq, paddle.full(
                        shape=[batch_size, beam_size, 1],
                        dtype="int64",
                        fill_value=self.eos_id)
                ],
                axis=2)
            # Set the scores of the unfinished seq in curr_seq to large
            # negative values
            curr_scores += (1. - curr_finished) * -inf
            # concatenating the sequences and scores along beam axis
            curr_finished_seq = paddle.concat([finished_seq, curr_seq], axis=1)
            curr_finished_scores = paddle.concat(
                [finished_scores, curr_scores], axis=1)
            curr_finished_flags = paddle.concat(
                [finished_flags, curr_finished], axis=1)
            # `beam_size * 3` rows per batch: `beam_size` from the finished
            # queue plus `2 * beam_size` candidates from grow_topk
            _, topk_indexes = paddle.topk(curr_finished_scores, k=beam_size)
            finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
                                               beam_size * 3, batch_size)
            finished_scores = gather_2d_by_gather(
                curr_finished_scores, topk_indexes, beam_size * 3, batch_size)
            finished_flags = gather_2d_by_gather(
                curr_finished_flags, topk_indexes, beam_size * 3, batch_size)
            return finished_seq, finished_scores, finished_flags

        for i in range(max_len):
            trg_pos = paddle.full(
                shape=trg_word.shape, dtype="int64", fill_value=i)
            trg_emb = self.trg_word_embedding(trg_word)
            trg_pos_emb = self.trg_pos_embedding(trg_pos)
            trg_emb = trg_emb + trg_pos_emb
            dec_input = F.dropout(
                trg_emb, p=self.dropout,
                training=self.training) if self.dropout else trg_emb

            logits, caches = self.transformer.decoder(
                dec_input, enc_output, None, trg_src_attn_bias, caches)
            logits = paddle.reshape(
                logits,
                shape=[-1, logits.shape[-1]])
            logits = self.linear(logits)

            topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
                i, logits, alive_seq, alive_log_probs, caches)
            alive_seq, alive_log_probs, states = grow_alive(
                topk_seq, topk_scores, topk_log_probs, topk_finished, states)
            caches = states
            finished_seq, finished_scores, finished_flags = grow_finished(
                finished_seq, finished_scores, finished_flags, topk_seq,
                topk_scores, topk_finished)
            trg_word = paddle.reshape(alive_seq[:, :, -1],
                                      [batch_size * beam_size, 1])

            if early_finish(alive_log_probs, finished_scores,
                            finished_flags).numpy():
                break

        return finished_seq, finished_scores


class InferTransformerModel(TransformerModel):
    """
