@@ -73,31 +73,34 @@ def main():
73
73
cost = seqToseq_net_v2 (source_dict_dim , target_dict_dim )
74
74
parameters = paddle .parameters .create (cost )
75
75
76
+ # define optimize method and trainer
76
77
optimizer = paddle .optimizer .Adam (learning_rate = 1e-4 )
77
-
78
- def event_handler (event ):
79
- if isinstance (event , paddle .event .EndIteration ):
80
- if event .batch_id % 10 == 0 :
81
- print "Pass %d, Batch %d, Cost %f, %s" % (
82
- event .pass_id , event .batch_id , event .cost , event .metrics )
83
-
84
78
trainer = paddle .trainer .SGD (cost = cost ,
85
79
parameters = parameters ,
86
80
update_equation = optimizer )
87
81
82
+ # define data reader
88
83
reader_dict = {
89
84
'source_language_word' : 0 ,
90
85
'target_language_word' : 1 ,
91
86
'target_language_next_word' : 2
92
87
}
93
88
94
- trn_reader = paddle .reader .batched (
89
+ wmt14_reader = paddle .reader .batched (
95
90
paddle .reader .shuffle (
96
91
train_reader ("data/pre-wmt14/train/train" ), buf_size = 8192 ),
97
92
batch_size = 5 )
98
93
94
+ # define event_handler callback
95
+ def event_handler (event ):
96
+ if isinstance (event , paddle .event .EndIteration ):
97
+ if event .batch_id % 10 == 0 :
98
+ print "Pass %d, Batch %d, Cost %f, %s" % (
99
+ event .pass_id , event .batch_id , event .cost , event .metrics )
100
+
101
+ # start to train
99
102
trainer .train (
100
- reader = trn_reader ,
103
+ reader = wmt14_reader ,
101
104
event_handler = event_handler ,
102
105
num_passes = 10000 ,
103
106
reader_dict = reader_dict )
0 commit comments