12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
"""
15
- ACL2016 Multimodal Machine Translation. Please see this websit for more details:
16
- http://www.statmt.org/wmt16/multimodal-task.html#task1
15
+ ACL2016 Multimodal Machine Translation. Please see this website for more
16
+ details: http://www.statmt.org/wmt16/multimodal-task.html#task1
17
17
18
18
If you use the dataset created for your task, please cite the following paper:
19
19
Multi30K: Multilingual English-German Image Descriptions.
56
56
UNK_MARK = "<unk>"
57
57
58
58
59
- def __build_dict__ (tar_file , dict_size , save_path , lang ):
59
+ def __build_dict (tar_file , dict_size , save_path , lang ):
60
60
word_dict = defaultdict (int )
61
61
with tarfile .open (tar_file , mode = "r" ) as f :
62
62
for line in f .extractfile ("wmt16/train" ):
@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
75
75
fout .write ("%s\n " % (word [0 ]))
76
76
77
77
78
- def __load_dict__ (tar_file , dict_size , lang , reverse = False ):
78
+ def __load_dict (tar_file , dict_size , lang , reverse = False ):
79
79
dict_path = os .path .join (paddle .v2 .dataset .common .DATA_HOME ,
80
80
"wmt16/%s_%d.dict" % (lang , dict_size ))
81
81
if not os .path .exists (dict_path ) or (
82
82
len (open (dict_path , "r" ).readlines ()) != dict_size ):
83
- __build_dict__ (tar_file , dict_size , dict_path , lang )
83
+ __build_dict (tar_file , dict_size , dict_path , lang )
84
84
85
85
word_dict = {}
86
86
with open (dict_path , "r" ) as fdict :
@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
92
92
return word_dict
93
93
94
94
95
- def __get_dict_size__ (src_dict_size , trg_dict_size , src_lang ):
95
+ def __get_dict_size (src_dict_size , trg_dict_size , src_lang ):
96
96
src_dict_size = min (src_dict_size , (TOTAL_EN_WORDS if src_lang == "en" else
97
97
TOTAL_DE_WORDS ))
98
98
trg_dict_size = min (trg_dict_size , (TOTAL_DE_WORDS if src_lang == "en" else
@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
102
102
103
103
def reader_creator (tar_file , file_name , src_dict_size , trg_dict_size , src_lang ):
104
104
def reader ():
105
- src_dict = __load_dict__ (tar_file , src_dict_size , src_lang )
106
- trg_dict = __load_dict__ (tar_file , trg_dict_size ,
107
- ("de" if src_lang == "en" else "en" ))
105
+ src_dict = __load_dict (tar_file , src_dict_size , src_lang )
106
+ trg_dict = __load_dict (tar_file , trg_dict_size ,
107
+ ("de" if src_lang == "en" else "en" ))
108
108
109
109
# the indices for start mark, end mark, and unk are the same in source
110
110
# language and target language. Here uses the source language
@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
173
173
174
174
assert (src_lang in ["en" , "de" ], ("An error language type. Only support: "
175
175
"en (for English); de(for Germany)" ))
176
- src_dict_size , trg_dict_size = __get_dict_size__ (src_dict_size ,
177
- trg_dict_size , src_lang )
176
+ src_dict_size , trg_dict_size = __get_dict_size (src_dict_size , trg_dict_size ,
177
+ src_lang )
178
178
179
179
return reader_creator (
180
180
tar_file = paddle .v2 .dataset .common .download (DATA_URL , "wmt16" , DATA_MD5 ,
@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
222
222
("An error language type. "
223
223
"Only support: en (for English); de(for Germany)" ))
224
224
225
- src_dict_size , trg_dict_size = __get_dict_size__ (src_dict_size ,
226
- trg_dict_size , src_lang )
225
+ src_dict_size , trg_dict_size = __get_dict_size (src_dict_size , trg_dict_size ,
226
+ src_lang )
227
227
228
228
return reader_creator (
229
229
tar_file = paddle .v2 .dataset .common .download (DATA_URL , "wmt16" , DATA_MD5 ,
@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
269
269
assert (src_lang in ["en" , "de" ],
270
270
("An error language type. "
271
271
"Only support: en (for English); de(for Germany)" ))
272
- src_dict_size , trg_dict_size = __get_dict_size__ (src_dict_size ,
273
- trg_dict_size , src_lang )
272
+ src_dict_size , trg_dict_size = __get_dict_size (src_dict_size , trg_dict_size ,
273
+ src_lang )
274
274
275
275
return reader_creator (
276
276
tar_file = paddle .v2 .dataset .common .download (DATA_URL , "wmt16" , DATA_MD5 ,
@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
308
308
"Please invoke paddle.dataset.wmt16.train/test/validation "
309
309
"first to build the dictionary." )
310
310
tar_file = os .path .join (paddle .v2 .dataset .common .DATA_HOME , "wmt16.tar.gz" )
311
- return __load_dict__ (tar_file , dict_size , lang , reverse )
311
+ return __load_dict (tar_file , dict_size , lang , reverse )
312
312
313
313
314
314
def fetch ():
0 commit comments