Skip to content

Commit 8fa09b8

Browse files
committed
add some comment for api_train_v2 of seqtoseq
1 parent c5fb4fd commit 8fa09b8

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

demo/seqToseq/api_train_v2.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,31 +73,34 @@ def main():
7373
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim)
7474
parameters = paddle.parameters.create(cost)
7575

76+
# define optimize method and trainer
7677
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
77-
78-
def event_handler(event):
79-
if isinstance(event, paddle.event.EndIteration):
80-
if event.batch_id % 10 == 0:
81-
print "Pass %d, Batch %d, Cost %f, %s" % (
82-
event.pass_id, event.batch_id, event.cost, event.metrics)
83-
8478
trainer = paddle.trainer.SGD(cost=cost,
8579
parameters=parameters,
8680
update_equation=optimizer)
8781

82+
# define data reader
8883
reader_dict = {
8984
'source_language_word': 0,
9085
'target_language_word': 1,
9186
'target_language_next_word': 2
9287
}
9388

94-
trn_reader = paddle.reader.batched(
89+
wmt14_reader = paddle.reader.batched(
9590
paddle.reader.shuffle(
9691
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
9792
batch_size=5)
9893

94+
# define event_handler callback
95+
def event_handler(event):
96+
if isinstance(event, paddle.event.EndIteration):
97+
if event.batch_id % 10 == 0:
98+
print "Pass %d, Batch %d, Cost %f, %s" % (
99+
event.pass_id, event.batch_id, event.cost, event.metrics)
100+
101+
# start to train
99102
trainer.train(
100-
reader=trn_reader,
103+
reader=wmt14_reader,
101104
event_handler=event_handler,
102105
num_passes=10000,
103106
reader_dict=reader_dict)

0 commit comments

Comments
 (0)