Skip to content

Commit b22cd96

Browse files
authored
Merge pull request #1761 from jacquesqiao/beam_search
support Beam search in v2 api
2 parents 4b5a432 + b669b5f commit b22cd96

File tree

4 files changed

+247
-109
lines changed

4 files changed

+247
-109
lines changed

demo/seqToseq/api_train_v2.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import sys
2+
23
import paddle.v2 as paddle
34

45

5-
def seqToseq_net(source_dict_dim, target_dict_dim):
6+
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
67
### Network Architecture
78
word_vector_dim = 512 # dimension of word vector
89
decoder_size = 512 # dimension of hidden unit in GRU Decoder network
910
encoder_size = 512 # dimension of hidden unit in GRU Encoder network
1011

12+
beam_size = 3
13+
max_length = 250
14+
1115
#### Encoder
1216
src_word_id = paddle.layer.data(
1317
name='source_language_word',
@@ -67,30 +71,57 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
6771
group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
6872
group_inputs = [group_input1, group_input2]
6973

70-
trg_embedding = paddle.layer.embedding(
71-
input=paddle.layer.data(
72-
name='target_language_word',
73-
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
74-
size=word_vector_dim,
75-
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
76-
group_inputs.append(trg_embedding)
77-
78-
# For decoder equipped with attention mechanism, in training,
79-
# target embedding (the ground truth) is the data input,
80-
# while encoded source sequence is accessed to as an unbounded memory.
81-
# Here, the StaticInput defines a read-only memory
82-
# for the recurrent_group.
83-
decoder = paddle.layer.recurrent_group(
84-
name=decoder_group_name,
85-
step=gru_decoder_with_attention,
86-
input=group_inputs)
87-
88-
lbl = paddle.layer.data(
89-
name='target_language_next_word',
90-
type=paddle.data_type.integer_value_sequence(target_dict_dim))
91-
cost = paddle.layer.classification_cost(input=decoder, label=lbl)
92-
93-
return cost
74+
if not is_generating:
75+
trg_embedding = paddle.layer.embedding(
76+
input=paddle.layer.data(
77+
name='target_language_word',
78+
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
79+
size=word_vector_dim,
80+
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
81+
group_inputs.append(trg_embedding)
82+
83+
# For decoder equipped with attention mechanism, in training,
84+
# target embedding (the ground truth) is the data input,
85+
# while encoded source sequence is accessed to as an unbounded memory.
86+
# Here, the StaticInput defines a read-only memory
87+
# for the recurrent_group.
88+
decoder = paddle.layer.recurrent_group(
89+
name=decoder_group_name,
90+
step=gru_decoder_with_attention,
91+
input=group_inputs)
92+
93+
lbl = paddle.layer.data(
94+
name='target_language_next_word',
95+
type=paddle.data_type.integer_value_sequence(target_dict_dim))
96+
cost = paddle.layer.classification_cost(input=decoder, label=lbl)
97+
98+
return cost
99+
else:
100+
# In generation, the decoder predicts the next target word based on
101+
# the encoded source sequence and the last generated target word.
102+
103+
# The encoded source sequence (encoder's output) must be specified by
104+
# StaticInput, which is a read-only memory.
105+
# Embedding of the last generated word is automatically retrieved by
106+
# GeneratedInputs, which is initialized by a start mark, such as <s>,
107+
# and must be included in generation.
108+
109+
trg_embedding = paddle.layer.GeneratedInputV2(
110+
size=target_dict_dim,
111+
embedding_name='_target_language_embedding',
112+
embedding_size=word_vector_dim)
113+
group_inputs.append(trg_embedding)
114+
115+
beam_gen = paddle.layer.beam_search(
116+
name=decoder_group_name,
117+
step=gru_decoder_with_attention,
118+
input=group_inputs,
119+
bos_id=0,
120+
eos_id=1,
121+
beam_size=beam_size,
122+
max_length=max_length)
123+
124+
return beam_gen
94125

95126

96127
def main():

python/paddle/v2/config_base.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,16 @@ def __init__(self, name=None, parent_layers=None):
6767
self.name = name
6868
self.__context__ = {}
6969
self.__parent_layers__ = parent_layers
70-
self.__children_layers__ = [] # used for evaluator.
70+
# some layer may have some extra parent layer
71+
self.__extra_parent__ = []
72+
# used for evaluator.
73+
self.__children_layers__ = []
74+
75+
def extra_parent(self):
76+
return self.__extra_parent__
77+
78+
def append_extra_parent(self, parent):
79+
self.__extra_parent__.append(parent)
7180

7281
def append_child(self, layer, parent_names):
7382
self.__children_layers__.append((layer, parent_names))
@@ -78,14 +87,20 @@ def to_proto(self, context):
7887
"""
7988
self.__context__ = context
8089

81-
# short cut if myself is parsed before.
90+
# STEP: short cut if this layer is parsed before.
8291
if self.context_name() in context:
8392
if self.use_context_name():
8493
return context[self.context_name()]
8594
else:
8695
return context[self.name]
8796

88-
# parse parent before myself
97+
# STEP: parse extra_parent that is not used by this layer but must
98+
# be parsed before this layer.
99+
for p in self.__extra_parent__:
100+
p.to_proto(context=context)
101+
102+
# STEP: parse parent that is used by this layer, get the result and
103+
# insert into kwargs of the next layer's to_proto_impl method.
89104
kwargs = dict()
90105
for layer_name in self.__parent_layers__:
91106
if not isinstance(self.__parent_layers__[layer_name],
@@ -97,14 +112,13 @@ def to_proto(self, context):
97112
self.__parent_layers__[layer_name])
98113
kwargs[layer_name] = v1_layer
99114

100-
# parse myself.
115+
# STEP: parse myself and add myself into context.
101116
ret_val = self.to_proto_impl(**kwargs)
102-
103-
if self.context_name() is not None and \
104-
self.context_name() not in context:
117+
if self.context_name() is not None \
118+
and self.context_name() not in context:
105119
context[self.context_name()] = ret_val
106120

107-
# parse children.
121+
# STEP: parse children that should be parsed after this layer.
108122
for layer, pnames in self.__children_layers__:
109123
drop = False
110124

@@ -117,6 +131,7 @@ def to_proto(self, context):
117131
continue
118132
layer.to_proto(context=context)
119133

134+
# STEP: return v1 layer result
120135
if self.context_name() is None:
121136
return ret_val
122137
elif self.use_context_name():

0 commit comments

Comments
 (0)