Commit 04876d0

Merge pull request #98 from luotao1/beam
update beam_search and seqToseq config, and add ExpActivation api
2 parents: 425e5b0 + d2e1b46

File tree: 4 files changed, +68 / -99 lines


demo/seqToseq/seqToseq_net.py

Lines changed: 9 additions & 13 deletions
@@ -128,12 +128,16 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
         return out
 
     decoder_group_name = "decoder_group"
+    group_inputs=[StaticInput(input=encoded_vector,is_seq=True),
+                  StaticInput(input=encoded_proj,is_seq=True)]
+
     if not is_generating:
         trg_embedding = embedding_layer(
             input=data_layer(name='target_language_word',
                              size=target_dict_dim),
             size=word_vector_dim,
             param_attr=ParamAttr(name='_target_language_embedding'))
+        group_inputs.append(trg_embedding)
 
         # For decoder equipped with attention mechanism, in training,
         # target embeding (the groudtruth) is the data input,
@@ -142,22 +146,13 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
         # for the recurrent_group.
         decoder = recurrent_group(name=decoder_group_name,
                                   step=gru_decoder_with_attention,
-                                  input=[
-                                      StaticInput(input=encoded_vector,
-                                                  is_seq=True),
-                                      StaticInput(input=encoded_proj,
-                                                  is_seq=True), trg_embedding
-                                  ])
+                                  input=group_inputs)
 
         lbl = data_layer(name='target_language_next_word',
                          size=target_dict_dim)
-        cost = classification_cost(input=decoder, label=lbl, )
+        cost = classification_cost(input=decoder, label=lbl)
         outputs(cost)
     else:
-        gen_inputs = [StaticInput(input=encoded_vector,
-                                  is_seq=True),
-                      StaticInput(input=encoded_proj,
-                                  is_seq=True), ]
         # In generation, the decoder predicts a next target word based on
         # the encoded source sequence and the last generated target word.
 
@@ -171,10 +166,11 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
                               size=target_dict_dim,
                               embedding_name='_target_language_embedding',
                               embedding_size=word_vector_dim)
-        gen_inputs.append(trg_embedding)
+        group_inputs.append(trg_embedding)
+
        beam_gen = beam_search(name=decoder_group_name,
                               step=gru_decoder_with_attention,
-                              input=gen_inputs,
+                              input=group_inputs,
                               id_input=data_layer(name="sent_id",
                                                   size=1),
                               dict_file=trg_dict_path,
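
The point of this refactor: training and generation previously built two near-identical input lists (the inline `input=[...]` versus `gen_inputs`); both branches now share one `group_inputs` list and append only the part that differs, the target-word embedding. A condensed sketch of the resulting pattern (hypothetical, simplified from the diff above):

# Condensed sketch of the shared-input pattern (not part of the commit):
# the two StaticInput entries are common to training and generation;
# only the target-word embedding differs per branch.
group_inputs = [StaticInput(input=encoded_vector, is_seq=True),
                StaticInput(input=encoded_proj, is_seq=True)]

if not is_generating:
    group_inputs.append(trg_embedding)    # embedding of ground-truth words
    decoder = recurrent_group(name=decoder_group_name,
                              step=gru_decoder_with_attention,
                              input=group_inputs)
else:
    group_inputs.append(trg_embedding)    # GeneratedInput-backed embedding
    beam_gen = beam_search(name=decoder_group_name,
                           step=gru_decoder_with_attention,
                           input=group_inputs)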

doc/ui/api/trainer_config_helpers/activations.rst

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,13 @@ AbsActivation
     :members: AbsActivation
     :noindex:
 
+ExpActivation
+===============
+
+.. automodule:: paddle.trainer_config_helpers.activations
+    :members: ExpActivation
+    :noindex:
+
 IdentityActivation
 ==================
 
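
For reference, a minimal usage sketch of the newly documented activation (layer names are illustrative; `data_layer` and `fc_layer` are the standard trainer_config_helpers calls):

from paddle.trainer_config_helpers import *

x = data_layer(name="x", size=10)
# the fc_layer output is passed through f(z) = e^z
y = fc_layer(input=x, size=10, act=ExpActivation())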

paddle/trainer/tests/sample_trainer_rnn_gen.conf

Lines changed: 42 additions & 85 deletions
@@ -13,96 +13,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
 
-import math
+from paddle.trainer_config_helpers import *
 
-beam_search = get_config_arg('beam_search', bool, False)
-
-model_type("recurrent_nn")
-
-Settings(learning_rate=0, batch_size=15, algorithm='sgd')
-
-Inputs("sent_id", "dummy_data_input")
-Outputs("predict_word")
+settings(batch_size=15, learning_rate=0)
 
 num_words = 5
+beam_flag = get_config_arg('beam_search', bool, False)
 
-DataLayer(name="sent_id", size=1, )
+sent_id = data_layer(name="sent_id", size=1)
 
 # This layer has no actual use, but only to decide batch_size in generation.
 # When generating, at least one Memory in RecurrentLayer MUST have a boot layer.
-DataLayer(name="dummy_data_input", size=2, )
-
-if beam_search:
-    RecurrentLayerGroupBegin("decoding_layer_group",
-                             in_links=[],
-                             out_links=["predict_word"],
-                             generator=Generator(max_num_frames=10,
-                                                 beam_size=2,
-                                                 num_results_per_sample=2, ))
-else:
-    RecurrentLayerGroupBegin("decoding_layer_group",
-                             in_links=[],
-                             out_links=["predict_word"],
-                             generator=Generator(max_num_frames=10, ))
-dummy_memory = Memory(name="dummy_memory",
-                      size=2,
-                      boot_layer="dummy_data_input")
-MixedLayer(name="dummy_memory",
-           size=2,
-           bias=False,
-           inputs=[IdentityProjection(dummy_memory)], )
-state_memory = Memory(name="state",
-                      size=num_words,
-                      #boot_bias=True,
-                      #boot_bias_active_type = "tanh",
-                      )
-
-predict_word_memory = Memory(name="predict_word",
-                             size=num_words,
-                             boot_with_const_id=0, )
-
-MixedLayer(
-    name = "word_embedding",
-    size = num_words, # word embedding dim is the same as num_words in this test.
-    bias = False,
-    inputs = TableProjection(predict_word_memory,
-                             initial_std=1,
-                             learning_rate=0,
-                             parameter_name="wordvec"))
-
-Layer( # simplified RNN for testing
-    name="state",
-    type="mixed",
-    size=num_words,
-    bias=False,
-    inputs=[FullMatrixProjection("word_embedding",
-                                 parameter_name="transtable")])
-
-Layer(name="output",
-      type="mixed",
-      size=num_words,
-      active_type="exponential",
-      bias=False,
-      inputs=TransposedFullMatrixProjection("state",
-                                            initial_std=1,
-                                            learning_rate=0,
-                                            parameter_name="wordvec"), )
-
-Layer(name="predict_word", type="maxid", inputs=["output"], )
-
-Layer(name="eos_check",
-      type="eos_id",
-      eos_id=num_words - 1,
-      inputs=["predict_word"], )
-RecurrentLayerGroupEnd("decoding_layer_group")
-
-Evaluator(name="answer_printer",
-          type="seq_text_printer",
-          dict_file="./trainer/tests/test_gen_dict.txt",
-          result_file="./trainer/tests/dump_text.test",
-          inputs=[
-              "sent_id",
-              "predict_word",
-          ], )
+dummy_data = data_layer(name="dummy_data_input", size=2)
+
+gen_inputs = [StaticInput(input=dummy_data, size=2),
+              GeneratedInput(size=num_words,
+                             embedding_name="wordvec",
+                             embedding_size=num_words)]
+
+def step(dummy_memory, predict_word):
+
+    # simplified RNN for testing
+    with mixed_layer(size=num_words) as layer:
+        layer += full_matrix_projection(input=predict_word,
+                                        param_attr=ParamAttr(name="transtable"))
+
+    with mixed_layer(size=num_words, act=ExpActivation()) as out:
+        out += trans_full_matrix_projection(input=layer,
+                                            param_attr=ParamAttr(name="wordvec"))
+
+    return out
+
+beam_gen = beam_search(name="rnn_gen",
+                       step=step,
+                       input=gen_inputs,
+                       id_input=sent_id,
+                       dict_file="./trainer/tests/test_gen_dict.txt",
+                       result_file="./trainer/tests/dump_text.test",
+                       bos_id=0,
+                       eos_id=num_words-1,
+                       beam_size=2 if beam_flag else 1,
+                       num_results_per_sample=2 if beam_flag else 1,
+                       max_length=10)
+
+#outputs(beam_gen)
+# In this config, as dummy_data_input doesn't work on beam_gen (we can find dummy_memory
+# is read-only memory, and isn't used by other layers of step), we show the Inputs and Outputs
+# as follows. Note that "__beam_search_predict__" is the default output name of beam_search.
+Inputs("sent_id","dummy_data_input")
+Outputs("__beam_search_predict__")
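
The old config's `if beam_search:` branch collapses into the inline `beam_size=2 if beam_flag else 1` switch: with the flag unset, a beam of width one is plain greedy decoding. A toy, self-contained sketch of that equivalence (plain Python, not Paddle's implementation):

import math

def toy_beam_search(step_probs, beam_size, eos_id, max_length):
    """step_probs(prefix) -> list of (token_id, prob) for the next token."""
    beams = [([], 0.0)]                    # (token prefix, log-probability)
    for _ in range(max_length):
        candidates = []
        for prefix, score in beams:
            if prefix and prefix[-1] == eos_id:
                candidates.append((prefix, score))   # finished beam carries over
                continue
            for tok, p in step_probs(prefix):
                candidates.append((prefix + [tok], score + math.log(p)))
        # with beam_size == 1 this keeps only the locally best token: greedy search
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams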

python/paddle/trainer_config_helpers/activations.py

Lines changed: 10 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 __all__ = ["TanhActivation", "SigmoidActivation",
            "SoftmaxActivation", "IdentityActivation", "LinearActivation",
-           'SequenceSoftmaxActivation',
+           'SequenceSoftmaxActivation', 'ExpActivation',
            "ReluActivation", "BReluActivation", "SoftReluActivation", "STanhActivation",
            "AbsActivation", "SquareActivation", "BaseActivation"]
 
@@ -185,3 +185,12 @@ class SquareActivation(BaseActivation):
     """
 
     def __init__(self): BaseActivation.__init__(self, 'square', False)
+
+class ExpActivation(BaseActivation):
+    """
+    Exponential Activation.
+
+    .. math::
+       f(z) = e^z.
+    """
+    def __init__(self): BaseActivation.__init__(self, 'exponential', False)
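
Note that 'exponential' is the same activation name the deleted test config referenced via `active_type="exponential"`; the new class only exposes it through the config-helper API. A quick check of the documented function outside Paddle (plain NumPy):

import numpy as np

z = np.array([-1.0, 0.0, 1.0])
print(np.exp(z))   # f(z) = e^z -> [0.36787944 1.         2.71828183]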
