Commit d6c6a99
Author: Helin Wang (committed)
dataset reader for wmt14

Example usage:

    import paddle.v2 as paddle

    if __name__ == '__main__':
        dict_en, dict_fr = paddle.dataset.wmt14.build_dict()
        train = paddle.dataset.wmt14.train(dict_en, dict_fr)
        test = paddle.dataset.wmt14.test(dict_en, dict_fr)

        total_train = 0
        for i in train():
            total_train += 1

        total_test = 0
        for i in test():
            total_test += 1

        print total_train, total_test
1 parent 349e799 commit d6c6a99

File tree

2 files changed (+144, -1 lines)

2 files changed

+144
-1
lines changed

python/paddle/v2/dataset/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -20,8 +20,9 @@
 import conll05
 import uci_housing
 import sentiment
+import wmt14

 __all__ = [
     'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment',
-    'uci_housing'
+    'uci_housing', 'wmt14'
 ]
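
Because __init__.py both imports wmt14 and lists it in __all__, the module shows up on the package's wildcard-import surface. A quick sanity check (a sketch, assuming the package is installed with the layout above):

    from paddle.v2.dataset import *

    print wmt14.__all__  # expected: ['train', 'test', 'build_dict']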

python/paddle/v2/dataset/wmt14.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
WMT14 dataset.
"""
import paddle.v2.dataset.common
import tarfile
import os.path
import itertools

__all__ = ['train', 'test', 'build_dict']

URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
URL_TRAIN = 'http://localhost:8000/train.tgz'  # placeholder: a local server, not a public mirror
MD5_TRAIN = '72de99da2830ea5a3a2c4eb36092bbc7'


def word_count(f, word_freq=None):
    add = paddle.v2.dataset.common.dict_add
    if word_freq is None:
        word_freq = {}

    for l in f:
        for w in l.strip().split():
            add(word_freq, w)
        add(word_freq, '<s>')
        add(word_freq, '<e>')

    return word_freq


def get_word_idx(word_freq):
    TYPO_FREQ = 50  # words appearing at most this often are treated as typos
    word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items())
    word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*word_freq_sorted))
    word_idx = dict(zip(words, xrange(len(words))))
    word_idx['<unk>'] = len(words)
    return word_idx


def get_word_freq(train, dev):
    word_freq = word_count(train, word_count(dev))
    if '<unk>' in word_freq:
        # drop <unk> here; get_word_idx re-adds it as the last index
        del word_freq['<unk>']
    return word_freq


def build_dict():
    base_dir = './wmt14-data'
    train_en_filename = base_dir + '/train/train.en'
    train_fr_filename = base_dir + '/train/train.fr'
    dev_en_filename = base_dir + '/dev/ntst1213.en'
    dev_fr_filename = base_dir + '/dev/ntst1213.fr'

    if not os.path.exists(train_en_filename) or not os.path.exists(
            train_fr_filename):
        with tarfile.open(
                paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14',
                                                  MD5_TRAIN)) as tf:
            tf.extractall(base_dir)

    if not os.path.exists(dev_en_filename) or not os.path.exists(
            dev_fr_filename):
        with tarfile.open(
                paddle.v2.dataset.common.download(URL_DEV_TEST, 'wmt14',
                                                  MD5_DEV_TEST)) as tf:
            tf.extractall(base_dir)

    f_en = open(train_en_filename)
    f_fr = open(train_fr_filename)
    f_en_dev = open(dev_en_filename)
    f_fr_dev = open(dev_fr_filename)

    word_freq_en = get_word_freq(f_en, f_en_dev)
    word_freq_fr = get_word_freq(f_fr, f_fr_dev)

    f_en.close()
    f_fr.close()
    f_en_dev.close()
    f_fr_dev.close()

    return get_word_idx(word_freq_en), get_word_idx(word_freq_fr)


def reader_creator(directory, path_en, path_fr, URL, MD5, dict_en, dict_fr):
    def reader():
        if not os.path.exists(path_en) or not os.path.exists(path_fr):
            with tarfile.open(
                    paddle.v2.dataset.common.download(URL, 'wmt14', MD5)) as tf:
                tf.extractall(directory)

        f_en = open(path_en)
        f_fr = open(path_fr)
        UNK_en = dict_en['<unk>']
        UNK_fr = dict_fr['<unk>']

        for en, fr in itertools.izip(f_en, f_fr):
            src_ids = [dict_en.get(w, UNK_en) for w in en.strip().split()]
            tar_ids = [
                dict_fr.get(w, UNK_fr)
                for w in ['<s>'] + fr.strip().split() + ['<e>']
            ]

            # skip empty pairs and sequences longer than 80 tokens
            if len(src_ids) == 0 or len(tar_ids) <= 1 or len(
                    src_ids) > 80 or len(tar_ids) > 80:
                continue

            yield src_ids, tar_ids[:-1], tar_ids[1:]

        f_en.close()
        f_fr.close()

    return reader


def train(dict_en, dict_fr):
    directory = './wmt14-data'
    return reader_creator(directory, directory + '/train/train.en',
                          directory + '/train/train.fr', URL_TRAIN, MD5_TRAIN,
                          dict_en, dict_fr)


def test(dict_en, dict_fr):
    directory = './wmt14-data'
    return reader_creator(directory, directory + '/dev/ntst1213.en',
                          directory + '/dev/ntst1213.fr', URL_DEV_TEST,
                          MD5_DEV_TEST, dict_en, dict_fr)
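
Taken together, a minimal end-to-end sketch of consuming the new reader (assuming the URLs above are reachable so the archives can be downloaded and extracted): build both vocabularies once, create the training reader, and inspect a single sample. Each sample is a triple of id lists in which the second and third entries are the target sentence offset by one token, the usual teacher-forcing layout.

    import paddle.v2 as paddle

    dict_en, dict_fr = paddle.dataset.wmt14.build_dict()
    train_reader = paddle.dataset.wmt14.train(dict_en, dict_fr)

    # Peek at the first (source, target input, target label) triple.
    src_ids, trg_in, trg_label = next(train_reader())
    print len(src_ids), len(trg_in), len(trg_label)
    # trg_in begins with the <s> id and trg_label ends with the <e> id;
    # otherwise they are the same sentence shifted by one position.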
