
Commit 430fdc5

Merge pull request #7661 from lcy-seso/wmt16_en_ger

Add WMT16 dataset.

2 parents: a6da470 + 2f344e7

File tree: 6 files changed, 488 additions and 26 deletions

python/paddle/v2/dataset/__init__.py (14 additions, 2 deletions)

@@ -24,11 +24,23 @@
 import uci_housing
 import sentiment
 import wmt14
+import wmt16
 import mq2007
 import flowers
 import voc2012

 __all__ = [
-    'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
-    'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
+    'mnist',
+    'imikolov',
+    'imdb',
+    'cifar',
+    'movielens',
+    'conll05',
+    'sentiment'
+    'uci_housing',
+    'wmt14',
+    'wmt16',
+    'mq2007',
+    'flowers',
+    'voc2012',
 ]
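
Note: in both the old and the new list, 'sentiment' is missing a trailing comma, so Python's implicit concatenation of adjacent string literals silently merges it with the next entry into the single (nonexistent) name 'sentimentuci_housing'. A quick illustration of the pitfall:

names = [
    'sentiment'
    'uci_housing',
]
print(names)  # ['sentimentuci_housing'] -- one element, not two

Adding the comma after 'sentiment' would keep the two exports separate.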

python/paddle/v2/dataset/common.py (15 additions, 6 deletions)

@@ -25,8 +25,12 @@
 import cPickle as pickle

 __all__ = [
-    'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
-    'convert'
+    'DATA_HOME',
+    'download',
+    'md5file',
+    'split',
+    'cluster_files_reader',
+    'convert',
 ]

 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')

@@ -58,12 +62,15 @@ def md5file(fname):
     return hash_md5.hexdigest()


-def download(url, module_name, md5sum):
+def download(url, module_name, md5sum, save_name=None):
     dirname = os.path.join(DATA_HOME, module_name)
     if not os.path.exists(dirname):
         os.makedirs(dirname)

-    filename = os.path.join(dirname, url.split('/')[-1])
+    filename = os.path.join(dirname,
+                            url.split('/')[-1]
+                            if save_name is None else save_name)
+
     retry = 0
     retry_limit = 3
     while not (os.path.exists(filename) and md5file(filename) == md5sum):

@@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
     Convert data from reader to recordio format files.

     :param output_path: directory in which output files will be saved.
-    :param reader: a data reader, from which the convert program will read data instances.
+    :param reader: a data reader, from which the convert program will read
+        data instances.
     :param name_prefix: the name prefix of generated files.
-    :param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
+    :param max_lines_to_shuffle: the max number of lines to shuffle before
+        writing.
     """

     assert line_count >= 1
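
Note: the new save_name argument lets a caller override the filename that download would otherwise derive from the last URL component. A minimal sketch, with hypothetical url/md5 values (a real call must pass the archive's true md5sum, since download retries until the checksum matches):

import paddle.v2.dataset.common as common

url = 'http://example.com/data/archive-v2.tgz'   # hypothetical URL
md5 = '0123456789abcdef0123456789abcdef'         # hypothetical checksum

# stored as ~/.cache/paddle/dataset/demo/data.tgz instead of
# ~/.cache/paddle/dataset/demo/archive-v2.tgz
path = common.download(url, 'demo', md5, save_name='data.tgz')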
New file: unit test for the WMT16 dataset (66 additions, 0 deletions)

@@ -0,0 +1,66 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.v2.dataset.wmt16
+import unittest
+
+
+class TestWMT16(unittest.TestCase):
+    def checkout_one_sample(self, sample):
+        # train data has 3 fields: source language word indices,
+        # target language word indices, and target next word indices.
+        self.assertEqual(len(sample), 3)
+
+        # test start mark and end mark in source word indices.
+        self.assertEqual(sample[0][0], 0)
+        self.assertEqual(sample[0][-1], 1)
+
+        # test start mark in target word indices.
+        self.assertEqual(sample[1][0], 0)
+
+        # test end mark in target next word indices.
+        self.assertEqual(sample[2][-1], 1)
+
+    def test_train(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.train(
+                    src_dict_size=100000, trg_dict_size=100000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_test(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.test(
+                    src_dict_size=1000, trg_dict_size=1000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_val(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.validation(
+                    src_dict_size=1000, trg_dict_size=1000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_get_dict(self):
+        dict_size = 1000
+        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
+        self.assertEqual(len(word_dict), dict_size)
+        self.assertEqual(word_dict[0], "<s>")
+        self.assertEqual(word_dict[1], "<e>")
+        self.assertEqual(word_dict[2], "<unk>")
+
+
+if __name__ == "__main__":
+    unittest.main()
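
Note: the tests pin down the public API of the new module: train/test/validation are reader creators parameterized by dictionary sizes, each sample is a (source ids, target ids, target next-word ids) triple framed by <s> (index 0) and <e> (index 1), and get_dict(lang, size, reverse) returns the vocabulary. A minimal usage sketch under exactly those assumptions, mirroring the positional get_dict call above:

import paddle.v2.dataset.wmt16 as wmt16

# reversed dictionary: {index: word}, with <s>, <e>, <unk> at 0, 1, 2
en_dict = wmt16.get_dict("en", 1000, True)

for src_ids, trg_ids, trg_next_ids in wmt16.train(
        src_dict_size=1000, trg_dict_size=1000)():
    # source sequences arrive wrapped in <s> ... <e>
    print(" ".join(en_dict[ix] for ix in src_ids))
    break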

python/paddle/v2/dataset/wmt14.py (19 additions, 11 deletions)

@@ -25,12 +25,20 @@
 import paddle.v2.dataset.common
 from paddle.v2.parameters import Parameters

-__all__ = ['train', 'test', 'build_dict', 'convert']
-
-URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
+__all__ = [
+    'train',
+    'test',
+    'get_dict',
+    'convert',
+]
+
+URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
+                'cslm_joint_paper/data/dev+test.tgz')
 MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
-# this is a small set of data for test. The original data is too large and will be add later.
-URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
+# this is a small set of data for test. The original data is too large and
+# will be added later.
+URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
+             'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'

@@ -42,8 +50,8 @@
 UNK_IDX = 2


-def __read_to_dict__(tar_file, dict_size):
-    def __to_dict__(fd, size):
+def __read_to_dict(tar_file, dict_size):
+    def __to_dict(fd, size):
         out_dict = dict()
         for line_count, line in enumerate(fd):
             if line_count < size:

@@ -58,19 +66,19 @@ def __to_dict__(fd, size):
             if each_item.name.endswith("src.dict")
         ]
         assert len(names) == 1
-        src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        src_dict = __to_dict(f.extractfile(names[0]), dict_size)
         names = [
             each_item.name for each_item in f
             if each_item.name.endswith("trg.dict")
         ]
         assert len(names) == 1
-        trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
         return src_dict, trg_dict


 def reader_creator(tar_file, file_name, dict_size):
     def reader():
-        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+        src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
         with tarfile.open(tar_file, mode='r') as f:
             names = [
                 each_item.name for each_item in f

@@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True):
     # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
     # else reverse = true, return dict = {'001':'a', '002':'b', ...}
     tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
-    src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+    src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
     if reverse:
         src_dict = {v: k for k, v in src_dict.items()}
         trg_dict = {v: k for k, v in trg_dict.items()}
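
Note: besides wrapping the over-long URL literals, this file renames __read_to_dict__ to __read_to_dict (trailing double underscores wrongly suggested a Python special name) and replaces the stale 'build_dict' export with 'get_dict'. A short sketch of the exported helper, assuming only the signature shown in the hunk above:

import paddle.v2.dataset.wmt14 as wmt14

# reverse=True (the default) returns {index: word} mappings
src_dict, trg_dict = wmt14.get_dict(30000)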
