Commit 2f344e7 (1 parent: 9a97c7f)

fix naming convention.
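
The rename drops the trailing double underscores from the module's helper functions. In Python, names with both leading and trailing double underscores (such as __init__ or __name__) are by convention reserved for the language's own "magic" attributes, and PEP 8 advises against inventing new ones; a leading underscore alone is the usual marker for an internal helper. A minimal sketch of the convention (illustrative only, not taken from the Paddle sources; __read_helper is a hypothetical name):

# Discouraged: the __x__ form looks like a Python special ("dunder") name.
def __read_helper__(path):
    return open(path).read()

# Preferred for a module-internal helper: the leading underscores mark it as
# private, and underscore-prefixed names are skipped by
# "from module import *" when no __all__ is defined.
def __read_helper(path):
    return open(path).read()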

2 files changed: 22 additions, 22 deletions

python/paddle/v2/dataset/wmt14.py (6 additions, 6 deletions)

@@ -50,8 +50,8 @@
 UNK_IDX = 2


-def __read_to_dict__(tar_file, dict_size):
-    def __to_dict__(fd, size):
+def __read_to_dict(tar_file, dict_size):
+    def __to_dict(fd, size):
         out_dict = dict()
         for line_count, line in enumerate(fd):
             if line_count < size:
@@ -66,19 +66,19 @@ def __to_dict__(fd, size):
             if each_item.name.endswith("src.dict")
         ]
         assert len(names) == 1
-        src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        src_dict = __to_dict(f.extractfile(names[0]), dict_size)
         names = [
             each_item.name for each_item in f
             if each_item.name.endswith("trg.dict")
         ]
         assert len(names) == 1
-        trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
         return src_dict, trg_dict


 def reader_creator(tar_file, file_name, dict_size):
     def reader():
-        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+        src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
         with tarfile.open(tar_file, mode='r') as f:
             names = [
                 each_item.name for each_item in f
@@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True):
     # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
     # else reverse = true, return dict = {'001':'a', '002':'b', ...}
     tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
-    src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+    src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
     if reverse:
         src_dict = {v: k for k, v in src_dict.items()}
         trg_dict = {v: k for k, v in trg_dict.items()}
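
Callers of the wmt14 module are unaffected, since the renamed helpers stay internal and the public get_dict(dict_size, reverse=True) signature shown above is unchanged. A hedged usage sketch (it assumes a working paddle.v2 installation and that the WMT14 archive can be downloaded; the dictionary size is an arbitrary example value):

import paddle.v2.dataset.wmt14 as wmt14

# With reverse=True the returned dicts map id -> word, as the comments in the
# hunk above describe; reverse=False would give word -> id.
src_dict, trg_dict = wmt14.get_dict(dict_size=30000, reverse=True)
print(len(src_dict), len(trg_dict))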

python/paddle/v2/dataset/wmt16.py (16 additions, 16 deletions)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-ACL2016 Multimodal Machine Translation. Please see this websit for more details:
-http://www.statmt.org/wmt16/multimodal-task.html#task1
+ACL2016 Multimodal Machine Translation. Please see this website for more
+details: http://www.statmt.org/wmt16/multimodal-task.html#task1

 If you use the dataset created for your task, please cite the following paper:
 Multi30K: Multilingual English-German Image Descriptions.
@@ -56,7 +56,7 @@
 UNK_MARK = "<unk>"


-def __build_dict__(tar_file, dict_size, save_path, lang):
+def __build_dict(tar_file, dict_size, save_path, lang):
     word_dict = defaultdict(int)
     with tarfile.open(tar_file, mode="r") as f:
         for line in f.extractfile("wmt16/train"):
@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
             fout.write("%s\n" % (word[0]))


-def __load_dict__(tar_file, dict_size, lang, reverse=False):
+def __load_dict(tar_file, dict_size, lang, reverse=False):
     dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                              "wmt16/%s_%d.dict" % (lang, dict_size))
     if not os.path.exists(dict_path) or (
             len(open(dict_path, "r").readlines()) != dict_size):
-        __build_dict__(tar_file, dict_size, dict_path, lang)
+        __build_dict(tar_file, dict_size, dict_path, lang)

     word_dict = {}
     with open(dict_path, "r") as fdict:
@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
     return word_dict


-def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
+def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
     src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
                                         TOTAL_DE_WORDS))
     trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):

 def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
     def reader():
-        src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
-        trg_dict = __load_dict__(tar_file, trg_dict_size,
-                                 ("de" if src_lang == "en" else "en"))
+        src_dict = __load_dict(tar_file, src_dict_size, src_lang)
+        trg_dict = __load_dict(tar_file, trg_dict_size,
+                               ("de" if src_lang == "en" else "en"))

         # the indice for start mark, end mark, and unk are the same in source
         # language and target language. Here uses the source language
@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):

     assert (src_lang in ["en", "de"], ("An error language type. Only support: "
                                        "en (for English); de(for Germany)"))
-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)

     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
            ("An error language type. "
             "Only support: en (for English); de(for Germany)"))

-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)

     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
     assert (src_lang in ["en", "de"],
             ("An error language type. "
              "Only support: en (for English); de(for Germany)"))
-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)

     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
         "Please invoke paddle.dataset.wmt16.train/test/validation "
         "first to build the dictionary.")
     tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
-    return __load_dict__(tar_file, dict_size, lang, reverse)
+    return __load_dict(tar_file, dict_size, lang, reverse)


 def fetch():
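
The wmt16 module follows the same pattern: __build_dict, __load_dict, and __get_dict_size become module-private helpers, while the public train/test/validation/get_dict entry points keep their signatures. A hedged usage sketch under the same assumptions (paddle.v2 installed, dataset downloadable; the vocabulary sizes are arbitrary example values):

import paddle.v2.dataset.wmt16 as wmt16

# Build a training data reader for English -> German translation.
train_reader = wmt16.train(src_dict_size=30000, trg_dict_size=30000,
                           src_lang="en")

# After train/test/validation has been invoked, the cached dictionary can be
# loaded; with reverse=False it maps word -> id.
en_dict = wmt16.get_dict("en", dict_size=30000, reverse=False)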
