@@ -126,51 +126,57 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
126
126
127
127
def main ():
128
128
paddle .init (use_gpu = False , trainer_count = 1 )
129
+ is_generating = True
129
130
130
131
# source and target dict dim.
131
132
dict_size = 30000
132
133
source_dict_dim = target_dict_dim = dict_size
133
134
134
- # define network topology
135
- cost = seqToseq_net (source_dict_dim , target_dict_dim )
136
- parameters = paddle .parameters .create (cost )
137
-
138
- # define optimize method and trainer
139
- optimizer = paddle .optimizer .Adam (
140
- learning_rate = 5e-5 ,
141
- regularization = paddle .optimizer .L2Regularization (rate = 1e-3 ))
142
- trainer = paddle .trainer .SGD (cost = cost ,
143
- parameters = parameters ,
144
- update_equation = optimizer )
145
-
146
- # define data reader
147
- feeding = {
148
- 'source_language_word' : 0 ,
149
- 'target_language_word' : 1 ,
150
- 'target_language_next_word' : 2
151
- }
152
-
153
- wmt14_reader = paddle .batch (
154
- paddle .reader .shuffle (
155
- paddle .dataset .wmt14 .train (dict_size = dict_size ), buf_size = 8192 ),
156
- batch_size = 5 )
157
-
158
- # define event_handler callback
159
- def event_handler (event ):
160
- if isinstance (event , paddle .event .EndIteration ):
161
- if event .batch_id % 10 == 0 :
162
- print "\n Pass %d, Batch %d, Cost %f, %s" % (
163
- event .pass_id , event .batch_id , event .cost , event .metrics )
164
- else :
165
- sys .stdout .write ('.' )
166
- sys .stdout .flush ()
167
-
168
- # start to train
169
- trainer .train (
170
- reader = wmt14_reader ,
171
- event_handler = event_handler ,
172
- num_passes = 10000 ,
173
- feeding = feeding )
135
+ # train the network
136
+ if not is_generating :
137
+ cost = seqToseq_net (source_dict_dim , target_dict_dim )
138
+ parameters = paddle .parameters .create (cost )
139
+
140
+ # define optimize method and trainer
141
+ optimizer = paddle .optimizer .Adam (
142
+ learning_rate = 5e-5 ,
143
+ regularization = paddle .optimizer .L2Regularization (rate = 8e-4 ))
144
+ trainer = paddle .trainer .SGD (cost = cost ,
145
+ parameters = parameters ,
146
+ update_equation = optimizer )
147
+ # define data reader
148
+ wmt14_reader = paddle .batch (
149
+ paddle .reader .shuffle (
150
+ paddle .dataset .wmt14 .train (dict_size ), buf_size = 8192 ),
151
+ batch_size = 5 )
152
+
153
+ # define event_handler callback
154
+ def event_handler (event ):
155
+ if isinstance (event , paddle .event .EndIteration ):
156
+ if event .batch_id % 10 == 0 :
157
+ print "\n Pass %d, Batch %d, Cost %f, %s" % (
158
+ event .pass_id , event .batch_id , event .cost ,
159
+ event .metrics )
160
+ else :
161
+ sys .stdout .write ('.' )
162
+ sys .stdout .flush ()
163
+
164
+ # start to train
165
+ trainer .train (
166
+ reader = wmt14_reader , event_handler = event_handler , num_passes = 2 )
167
+
168
+ # generate a english sequence to french
169
+ else :
170
+ gen_creator = paddle .dataset .wmt14 .test (dict_size )
171
+ gen_data = []
172
+ for item in gen_creator ():
173
+ gen_data .append ((item [0 ], ))
174
+ if len (gen_data ) == 3 :
175
+ break
176
+
177
+ beam_gen = seqToseq_net (source_dict_dim , target_dict_dim , is_generating )
178
+ parameters = paddle .dataset .wmt14 .model ()
179
+ trg_dict = paddle .dataset .wmt14 .trg_dict (dict_size )
174
180
175
181
176
182
if __name__ == '__main__' :
0 commit comments