Skip to content

Commit a42233c

Browse files
committed
add wmt14 trg_dict
1 parent caffcc8 commit a42233c

File tree

2 files changed

+53
-41
lines changed

2 files changed

+53
-41
lines changed

demo/seqToseq/api_train_v2.py

Lines changed: 46 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -126,51 +126,57 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
126126

127127
def main():
128128
paddle.init(use_gpu=False, trainer_count=1)
129+
is_generating = True
129130

130131
# source and target dict dim.
131132
dict_size = 30000
132133
source_dict_dim = target_dict_dim = dict_size
133134

134-
# define network topology
135-
cost = seqToseq_net(source_dict_dim, target_dict_dim)
136-
parameters = paddle.parameters.create(cost)
137-
138-
# define optimize method and trainer
139-
optimizer = paddle.optimizer.Adam(
140-
learning_rate=5e-5,
141-
regularization=paddle.optimizer.L2Regularization(rate=1e-3))
142-
trainer = paddle.trainer.SGD(cost=cost,
143-
parameters=parameters,
144-
update_equation=optimizer)
145-
146-
# define data reader
147-
feeding = {
148-
'source_language_word': 0,
149-
'target_language_word': 1,
150-
'target_language_next_word': 2
151-
}
152-
153-
wmt14_reader = paddle.batch(
154-
paddle.reader.shuffle(
155-
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192),
156-
batch_size=5)
157-
158-
# define event_handler callback
159-
def event_handler(event):
160-
if isinstance(event, paddle.event.EndIteration):
161-
if event.batch_id % 10 == 0:
162-
print "\nPass %d, Batch %d, Cost %f, %s" % (
163-
event.pass_id, event.batch_id, event.cost, event.metrics)
164-
else:
165-
sys.stdout.write('.')
166-
sys.stdout.flush()
167-
168-
# start to train
169-
trainer.train(
170-
reader=wmt14_reader,
171-
event_handler=event_handler,
172-
num_passes=10000,
173-
feeding=feeding)
135+
# train the network
136+
if not is_generating:
137+
cost = seqToseq_net(source_dict_dim, target_dict_dim)
138+
parameters = paddle.parameters.create(cost)
139+
140+
# define optimize method and trainer
141+
optimizer = paddle.optimizer.Adam(
142+
learning_rate=5e-5,
143+
regularization=paddle.optimizer.L2Regularization(rate=8e-4))
144+
trainer = paddle.trainer.SGD(cost=cost,
145+
parameters=parameters,
146+
update_equation=optimizer)
147+
# define data reader
148+
wmt14_reader = paddle.batch(
149+
paddle.reader.shuffle(
150+
paddle.dataset.wmt14.train(dict_size), buf_size=8192),
151+
batch_size=5)
152+
153+
# define event_handler callback
154+
def event_handler(event):
155+
if isinstance(event, paddle.event.EndIteration):
156+
if event.batch_id % 10 == 0:
157+
print "\nPass %d, Batch %d, Cost %f, %s" % (
158+
event.pass_id, event.batch_id, event.cost,
159+
event.metrics)
160+
else:
161+
sys.stdout.write('.')
162+
sys.stdout.flush()
163+
164+
# start to train
165+
trainer.train(
166+
reader=wmt14_reader, event_handler=event_handler, num_passes=2)
167+
168+
# translate an English sequence into French
169+
else:
170+
gen_creator = paddle.dataset.wmt14.test(dict_size)
171+
gen_data = []
172+
for item in gen_creator():
173+
gen_data.append((item[0], ))
174+
if len(gen_data) == 3:
175+
break
176+
177+
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
178+
parameters = paddle.dataset.wmt14.model()
179+
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
174180

175181

176182
if __name__ == '__main__':

python/paddle/v2/dataset/wmt14.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
3030
# this is the pretrained model, whose bleu = 26.92
3131
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
32-
MD5_MODEL = '6b097d23e15654608c6f74923e975535'
32+
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
3333

3434
START = "<s>"
3535
END = "<e>"
@@ -115,6 +115,12 @@ def model():
115115
return parameters
116116

117117

118+
def trg_dict(dict_size):
119+
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
120+
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
121+
return trg_dict
122+
123+
118124
def fetch():
119125
download(URL_TRAIN, 'wmt14', MD5_TRAIN)
120126
download(URL_MODEL, 'wmt14', MD5_MODEL)

0 commit comments

Comments
 (0)