Skip to content

Commit cf7f651

Browse files
committed
add wmt14 pretrained model
1 parent cd81075 commit cf7f651

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

python/paddle/v2/dataset/wmt14.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
wmt14 dataset
1616
"""
1717
import tarfile
18+
import gzip
1819

1920
from paddle.v2.dataset.common import download
21+
from paddle.v2.parameters import Parameters
2022

2123
__all__ = ['train', 'test', 'build_dict']
2224

@@ -25,6 +27,9 @@
2527
# this is a small set of data for test. The original data is too large and will be add later.
2628
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
2729
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
30+
# this is the pretrained model, whose bleu = 26.92
31+
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
32+
MD5_MODEL = '6b097d23e15654608c6f74923e975535'
2833

2934
START = "<s>"
3035
END = "<e>"
@@ -103,5 +108,13 @@ def test(dict_size):
103108
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
104109

105110

111+
def model():
112+
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
113+
with gzip.open(tar_file, 'r') as f:
114+
parameters = Parameters.from_tar(f)
115+
return parameters
116+
117+
106118
def fetch():
107119
download(URL_TRAIN, 'wmt14', MD5_TRAIN)
120+
download(URL_MODEL, 'wmt14', MD5_MODEL)

0 commit comments

Comments
 (0)