@@ -785,6 +785,249 @@ def forward(self, src_word, trg_word):
        return predict
+    def beam_search_v2(self, src_word, beam_size=4, max_len=None, alpha=0.6):
+        """
+        Beam search with two queues, alive and finished, each with its own
+        capacity of `beam_size`. It proceeds by three steps: `grow_topk`,
+        `grow_alive` and `grow_finished`.
+        1. `grow_topk` selects the top `2*beam_size` candidates, so the beam
+        cannot be exhausted even if all top `beam_size` candidates hit EOS.
+        2. `grow_alive` selects the top `beam_size` non-EOS candidates as the
+        inputs of the next decoding step.
+        3. `grow_finished` compares the candidates already in the finished
+        queue with the newly finished candidates from `grow_topk`, and keeps
+        the top `beam_size` finished candidates.
+        """
+
+        def expand_to_beam_size(tensor, beam_size):
+            tensor = paddle.reshape(tensor,
+                                    [tensor.shape[0], 1] + tensor.shape[1:])
+            tile_dims = [1] * len(tensor.shape)
+            tile_dims[1] = beam_size
+            return paddle.tile(tensor, tile_dims)
+
+        def merge_beam_dim(tensor):
+            return paddle.reshape(tensor, [-1] + tensor.shape[2:])
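+        # Shape sketch for the helpers above (illustrative): expand_to_beam_size
+        # maps a [B, L, D] tensor to [B, beam_size, L, D], and merge_beam_dim
+        # then folds it to [B * beam_size, L, D] so each beam rides along the
+        # batch dimension of the decoder.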
+
+        # run encoder; `bos_id` doubles as the pad id in this model, so pad
+        # positions get a large negative attention bias and position index 0
+        src_max_len = paddle.shape(src_word)[-1]
+        src_slf_attn_bias = paddle.cast(
+            src_word == self.bos_id,
+            dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
+        src_slf_attn_bias.stop_gradient = True
+        src_pos = paddle.cast(
+            src_word != self.bos_id, dtype="int64") * paddle.arange(
+                start=0, end=src_max_len)
+        src_emb = self.src_word_embedding(src_word)
+        src_pos_emb = self.src_pos_embedding(src_pos)
+        src_emb = src_emb + src_pos_emb
+        enc_input = F.dropout(
+            src_emb, p=self.dropout,
+            training=self.training) if self.dropout else src_emb
+
+        enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
+
+        # constants
+        inf = float(1. * 1e7)
+        batch_size = enc_output.shape[0]
+        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
+
+        ### initialize states of beam search ###
+        ## init for the alive ##
+        initial_log_probs = paddle.to_tensor(
+            np.array(
+                [[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
+        alive_log_probs = paddle.tile(initial_log_probs, [batch_size, 1])
+        alive_seq = paddle.to_tensor(
+            np.tile(
+                np.array(
+                    [[[self.bos_id]]], dtype="int64"),
+                (batch_size, beam_size, 1)))
+
+        ## init for the finished ##
+        finished_scores = paddle.to_tensor(
+            np.array(
+                [[-inf] * beam_size], dtype="float32"))
+        finished_scores = paddle.tile(finished_scores, [batch_size, 1])
+        finished_seq = paddle.to_tensor(
+            np.tile(
+                np.array(
+                    [[[self.bos_id]]], dtype="int64"),
+                (batch_size, beam_size, 1)))
+        finished_flags = paddle.zeros_like(finished_scores)
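+        # Only beam 0 of each batch item starts at log prob 0.; the remaining
+        # beams start at -inf so the first top-k cannot fill the beam with
+        # duplicates of BOS. The finished queue starts with -inf scores and
+        # zero flags, i.e. "no finished hypothesis yet".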
+
+        ### initialize inputs and states of transformer decoder ###
+        ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
+        trg_word = paddle.reshape(alive_seq[:, :, -1],
+                                  [batch_size * beam_size, 1])
+        trg_src_attn_bias = src_slf_attn_bias
+        trg_src_attn_bias = merge_beam_dim(
+            expand_to_beam_size(trg_src_attn_bias, beam_size))
+        enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
+
+        ## init states (caches) for transformer, need to be updated according to selected beam
+        caches = self.transformer.decoder.gen_cache(enc_output, do_zip=False)
+
+        def update_states(caches, beam_idx, beam_size):
+            # reorder the decoder's incremental k/v caches to follow the
+            # beams selected by the latest top-k
+            new_caches = []
+            for cache in caches:
+                k = gather_2d_by_gather(cache[0].k, beam_idx, beam_size,
+                                        batch_size, False)
+                v = gather_2d_by_gather(cache[0].v, beam_idx, beam_size,
+                                        batch_size, False)
+                new_caches.append((nn.MultiHeadAttention.Cache(k, v), cache[1]))
+            return new_caches
+
+        def gather_2d_by_gather(tensor_nd,
+                                beam_idx,
+                                beam_size,
+                                batch_size,
+                                need_flat=True):
+            # emulate a batched (2D) gather with flat `paddle.gather` by
+            # offsetting each batch item's beam indices into the flattened
+            # beam dimension
+            batch_idx = paddle.arange(
+                0, batch_size, 1, dtype="int64") * beam_size
+            flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
+            idx = paddle.reshape(
+                paddle.add(beam_idx, batch_idx.unsqueeze(-1)), [-1])
+            new_flat_tensor = paddle.gather(flat_tensor, idx)
+            new_tensor_nd = paddle.reshape(
+                new_flat_tensor,
+                shape=[batch_size, beam_idx.shape[1]] +
+                tensor_nd.shape[2:]) if need_flat else new_flat_tensor
+            return new_tensor_nd
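+        # Worked example (illustrative): with batch_size=2, beam_size=2 and
+        # beam_idx=[[1, 0], [0, 0]], batch_idx is [0, 2] and the flat gather
+        # indices become [1, 0, 2, 2]: batch item 0 swaps its two beams while
+        # batch item 1 duplicates its beam 0.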
+
+        def early_finish(alive_log_probs, finished_scores,
+                         finished_in_finished):
+            max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
+            # The best possible score of the most likely alive sequence
+            lower_bound_alive_scores = alive_log_probs[:,
+                                                       0] / max_length_penalty
+
+            # Now compute the lowest score of a finished sequence in finished.
+            # If a sequence isn't finished, we multiply its score by 0; since
+            # scores are all negative, taking the min gives the score of the
+            # lowest finished item.
+            lowest_score_of_finished_in_finished = paddle.min(
+                finished_scores * finished_in_finished, 1)
+            # If none of the sequences have finished, the min will be 0 and we
+            # have to replace it by -inf. The score of any sequence in alive
+            # will be much higher than -inf, so the termination condition will
+            # not be met.
+            lowest_score_of_finished_in_finished += (
+                1. - paddle.max(finished_in_finished, 1)) * -inf
+            bound_is_met = paddle.all(
+                paddle.greater_than(lowest_score_of_finished_in_finished,
+                                    lower_bound_alive_scores))
+
+            return bound_is_met
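+        # Reading the bound (illustrative): a finished score is
+        # log_prob / ((5 + len) / 6) ** alpha. alive_log_probs[:, 0] under the
+        # maximum length penalty is the best score any alive hypothesis could
+        # still reach, so search can stop once even that optimistic bound
+        # cannot beat the worst kept finished score for every batch item.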
+
+        def grow_topk(i, logits, alive_seq, alive_log_probs, states):
+            logits = paddle.reshape(logits, [batch_size, beam_size, -1])
+            candidate_log_probs = paddle.log(F.softmax(logits, axis=2))
+            log_probs = paddle.add(candidate_log_probs,
+                                   alive_log_probs.unsqueeze(-1))
+
+            # GNMT length penalty ((5 + len) / 6) ** alpha, parenthesized to
+            # match the formula used in early_finish above
+            length_penalty = np.power((5.0 + (i + 1.0)) / 6.0, alpha)
+            curr_scores = log_probs / length_penalty
+            flat_curr_scores = paddle.reshape(curr_scores, [batch_size, -1])
+
+            topk_scores, topk_ids = paddle.topk(
+                flat_curr_scores, k=beam_size * 2)
+
+            topk_log_probs = topk_scores * length_penalty
+
+            topk_beam_index = topk_ids // self.trg_vocab_size
+            topk_ids = topk_ids % self.trg_vocab_size
+
+            # use gather as gather_nd, TODO: use gather_nd
+            topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
+                                           beam_size, batch_size)
+            topk_seq = paddle.concat(
+                [topk_seq, paddle.reshape(topk_ids, topk_ids.shape + [1])],
+                axis=2)
+            states = update_states(states, topk_beam_index, beam_size)
+            eos = paddle.full(
+                shape=topk_ids.shape, dtype="int64", fill_value=self.eos_id)
+            topk_finished = paddle.cast(paddle.equal(topk_ids, eos), "float32")
+
+            # topk_seq: [batch_size, 2*beam_size, i+1]
+            # topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
+            return topk_seq, topk_log_probs, topk_scores, topk_finished, states
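+        # Worked example (illustrative, tiny numbers): with beam_size=2 and a
+        # 3-token vocab, flat_curr_scores has 2 * 3 = 6 columns per batch
+        # item; a flat top-k id of 5 decodes to beam 5 // 3 = 1 and token
+        # 5 % 3 = 2, which is exactly what topk_beam_index and topk_ids
+        # compute.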
+
+        def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
+                       states):
+            curr_scores += curr_finished * -inf
+            _, topk_indexes = paddle.topk(curr_scores, k=beam_size)
+            alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
+                                            beam_size * 2, batch_size)
+            alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
+                                                  beam_size * 2, batch_size)
+            states = update_states(states, topk_indexes, beam_size * 2)
+
+            return alive_seq, alive_log_probs, states
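+        # Masking note: grow_alive adds curr_finished * -inf so EOS-ended
+        # candidates drop out of the alive queue, while grow_finished below
+        # adds (1. - curr_finished) * -inf so only EOS-ended candidates can
+        # enter the finished queue.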
+
+        def grow_finished(finished_seq, finished_scores, finished_flags,
+                          curr_seq, curr_scores, curr_finished):
+            # pad the finished queue with one EOS step so it matches the
+            # length of the newly grown sequences
+            finished_seq = paddle.concat(
+                [
+                    finished_seq, paddle.full(
+                        shape=[batch_size, beam_size, 1],
+                        dtype="int64",
+                        fill_value=self.eos_id)
+                ],
+                axis=2)
+            # Set the scores of the unfinished seq in curr_seq to large
+            # negative values
+            curr_scores += (1. - curr_finished) * -inf
+            # concatenate the sequences and scores along the beam axis
+            curr_finished_seq = paddle.concat([finished_seq, curr_seq], axis=1)
+            curr_finished_scores = paddle.concat(
+                [finished_scores, curr_scores], axis=1)
+            curr_finished_flags = paddle.concat(
+                [finished_flags, curr_finished], axis=1)
+            _, topk_indexes = paddle.topk(curr_finished_scores, k=beam_size)
+            finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
+                                               beam_size * 3, batch_size)
+            finished_scores = gather_2d_by_gather(
+                curr_finished_scores, topk_indexes, beam_size * 3, batch_size)
+            finished_flags = gather_2d_by_gather(
+                curr_finished_flags, topk_indexes, beam_size * 3, batch_size)
+            return finished_seq, finished_scores, finished_flags
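+        # The top-k in grow_finished ranks beam_size previously finished
+        # candidates plus 2*beam_size newly grown ones, which is why
+        # `beam_size * 3` is passed to gather_2d_by_gather.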
+
+        for i in range(max_len):
+            trg_pos = paddle.full(
+                shape=trg_word.shape, dtype="int64", fill_value=i)
+            trg_emb = self.trg_word_embedding(trg_word)
+            trg_pos_emb = self.trg_pos_embedding(trg_pos)
+            trg_emb = trg_emb + trg_pos_emb
+            dec_input = F.dropout(
+                trg_emb, p=self.dropout,
+                training=self.training) if self.dropout else trg_emb
+
+            logits, caches = self.transformer.decoder(
+                dec_input, enc_output, None, trg_src_attn_bias, caches)
+            logits = paddle.reshape(logits, shape=[-1, logits.shape[-1]])
+            logits = self.linear(logits)
+
+            topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
+                i, logits, alive_seq, alive_log_probs, caches)
+            alive_seq, alive_log_probs, states = grow_alive(
+                topk_seq, topk_scores, topk_log_probs, topk_finished, states)
+            caches = states
+            finished_seq, finished_scores, finished_flags = grow_finished(
+                finished_seq, finished_scores, finished_flags, topk_seq,
+                topk_scores, topk_finished)
+            trg_word = paddle.reshape(alive_seq[:, :, -1],
+                                      [batch_size * beam_size, 1])
+
+            if early_finish(alive_log_probs, finished_scores,
+                            finished_flags).numpy():
+                break
+
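+        # Hedged usage sketch: on a trained model,
+        #   seq, scores = model.beam_search_v2(src_word, beam_size=4)
+        # returns `seq` of shape [batch_size, beam_size, length] (BOS-prefixed
+        # and EOS-padded) and `scores` of shape [batch_size, beam_size], best
+        # beam first.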
+        return finished_seq, finished_scores
+

class InferTransformerModel(TransformerModel):
"""