Commit 2197402

gongenlei and LiuChiachi authored
[BUGFIX] Cnn_dailymail and xnli raise error when downloading in multi-gpus mode (#1587)
* fix: multi-gpus count file_num
* fix: update xnli

Co-authored-by: Jiaqi Liu <[email protected]>
1 parent 1c10aba · commit 2197402

File tree: 4 files changed, +33 −20 lines changed


docs/data_prepare/dataset_list.md

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ PaddleNLP provides fast-loading APIs for the following datasets; in actual use, please select according to
 | [CLUEWSCF](https://github.com/CLUEbenchmark/FewCLUE/tree/main/datasets) | Chinese version of the WSC Winograd schema challenge from the FewCLUE benchmark; pronoun disambiguation, binary classification task | `paddlenlp.datasets.load_dataset('fewclue', 'cluewsc')`|
 | [THUCNews](https://github.com/gaussic/text-classification-cnn-rnn#%E6%95%B0%E6%8D%AE%E9%9B%86) | THUCNews Chinese news topic classification | `paddlenlp.datasets.load_dataset('thucnews')` |
 | [HYP](https://pan.webis.de/semeval19/semeval19-web/) | English political news sentiment classification corpus | `paddlenlp.datasets.load_dataset('hyp')` |
-| [XNLI](https://github.com/facebookresearch/XNLI) | Natural language inference dataset in 15 languages; three-class classification task. | `paddlenlp.datasets.load_dataset('xnli')`|
+| [XNLI](https://github.com/facebookresearch/XNLI) | Natural language inference dataset in 15 languages; three-class classification task. | `paddlenlp.datasets.load_dataset('xnli', 'ar')`|
 | [XNLI_CN](https://github.com/facebookresearch/XNLI) | Chinese natural language inference dataset (a subset of XNLI); three-class classification task. | `paddlenlp.datasets.load_dataset('xnli_cn')`|
 
 ## Text Matching
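For readers unfamiliar with the new call signature, a minimal usage sketch of the updated XNLI API from the table above: the language is now passed as the dataset config name (second positional argument) rather than a `language=` keyword. Split names follow the convention used in run_classifier.py below; the example field names are an assumption, not shown in this diff.

```python
from paddlenlp.datasets import load_dataset

# Load only the Arabic portion of XNLI; "ar" is the config name
# (second positional argument) introduced by this commit.
train_ds, test_ds = load_dataset("xnli", "ar", splits=("train", "test"))

# Each example is a dict of premise/hypothesis/label fields
# (assumed from the usual XNLI layout).
print(train_ds[0])
```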

examples/language_model/ernie-m/run_classifier.py

Lines changed: 3 additions & 3 deletions

@@ -189,7 +189,7 @@ def convert_example(example, tokenizer, max_seq_length=256):
 
 
 def get_test_dataloader(args, language, batchify_fn, trans_func):
-    test_ds = load_dataset("xnli", splits="test", language=language)
+    test_ds = load_dataset("xnli", language, splits="test")
     test_ds = test_ds.map(trans_func, lazy=True)
     test_batch_sampler = BatchSampler(
         test_ds, batch_size=args.batch_size, shuffle=False)
@@ -240,12 +240,12 @@ def do_train(args):
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)
    if args.task_type == "cross-lingual-transfer":
-        train_ds = load_dataset("xnli", splits="train", language="en")
+        train_ds = load_dataset("xnli", "en", splits="train")
         train_ds = train_ds.map(trans_func, lazy=True)
     elif args.task_type == "translate-train-all":
         all_train_ds = []
         for language in all_languages:
-            train_ds = load_dataset("xnli", splits="train", language=language)
+            train_ds = load_dataset("xnli", language, splits="train")
             all_train_ds.append(train_ds.map(trans_func, lazy=True))
         train_ds = XnliDataset(all_train_ds)
         train_batch_sampler = DistributedBatchSampler(
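The change here is purely in argument order: the language is now positional, selecting the dataset config, which matches the reader change in xnli.py further down. A standalone sketch of the "translate-train-all" pattern under the new signature; the explicit language list is an assumption, since this hunk only iterates over an `all_languages` variable defined elsewhere in the script.

```python
from paddlenlp.datasets import load_dataset

# The 15 XNLI language codes (assumed; defined elsewhere in the script).
all_languages = ["ar", "bg", "de", "el", "en", "es", "fr", "hi",
                 "ru", "sw", "th", "tr", "ur", "vi", "zh"]

# Under the new signature the language comes before `splits`.
all_train_ds = [load_dataset("xnli", lang, splits="train")
                for lang in all_languages]
```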

paddlenlp/datasets/cnn_dailymail.py

Lines changed: 14 additions & 8 deletions

@@ -16,9 +16,11 @@
 import collections
 import os
 import hashlib
+import shutil
 
 from paddle.dataset.common import md5file
-from paddlenlp.utils.downloader import get_path_from_url, _decompress
+from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints
+from paddle.distributed import ParallelEnv
 from paddlenlp.utils.env import DATA_HOME
 from paddlenlp.utils.log import logger
 from . import DatasetBuilder
@@ -190,13 +192,17 @@ def _get_data(self, mode):
             dir_path = os.path.join(default_root, k)
             if not os.path.exists(dir_path):
                 get_path_from_url(v["url"], default_root, v["md5"])
-            file_num = len(os.listdir(os.path.join(dir_path, "stories")))
-            if file_num != v["file_num"]:
-                logger.warning(
-                    "Number of %s stories is %d != %d, decompress again." %
-                    (k, file_num, v["file_num"]))
-                _decompress(
-                    os.path.join(default_root, os.path.basename(v["url"])))
+            unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                                     .trainer_endpoints[:])
+            if ParallelEnv().current_endpoint in unique_endpoints:
+                file_num = len(os.listdir(os.path.join(dir_path, "stories")))
+                if file_num != v["file_num"]:
+                    logger.warning(
+                        "Number of %s stories is %d != %d, decompress again." %
+                        (k, file_num, v["file_num"]))
+                    shutil.rmtree(os.path.join(dir_path, "stories"))
+                    _decompress(
+                        os.path.join(default_root, os.path.basename(v["url"])))
             dl_paths[k] = dir_path
         filename, url, data_hash = self.SPLITS[mode]
         fullname = os.path.join(default_root, filename)
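The core of the fix is the endpoint guard: `_get_unique_endpoints` keeps one trainer endpoint per unique machine, so only that process counts the extracted files and, on a mismatch, removes the stale directory and re-extracts, instead of every GPU worker racing to do so concurrently. A distilled sketch of the pattern, with the filesystem work elided; `_get_unique_endpoints` is a private Paddle helper, used here exactly as in the diff.

```python
from paddle.distributed import ParallelEnv
from paddle.utils.download import _get_unique_endpoints  # private Paddle helper

env = ParallelEnv()
# One endpoint per unique IP; in a single-process run this is just
# the current endpoint, so the guard is always taken.
unique_endpoints = _get_unique_endpoints(env.trainer_endpoints[:])
if env.current_endpoint in unique_endpoints:
    # Only this process touches the shared download directory:
    # count files, rmtree on mismatch, decompress again.
    pass
```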

paddlenlp/datasets/xnli.py

Lines changed: 15 additions & 8 deletions

@@ -16,9 +16,11 @@
 import os
 import csv
 from contextlib import ExitStack
+import shutil
 
 from paddle.dataset.common import md5file
-from paddle.utils.download import get_path_from_url, _decompress
+from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints
+from paddle.distributed import ParallelEnv
 from paddlenlp.utils.env import DATA_HOME
 from paddlenlp.utils.log import logger
 from . import DatasetBuilder
@@ -64,12 +66,17 @@ def _get_data(self, mode, **kwargs):
         if mode == 'train':
             if not os.path.exists(fullname):
                 get_path_from_url(url, default_root, zipfile_hash)
-            file_num = len(os.listdir(fullname))
-            if file_num != 15:
-                logger.warning(
-                    "Number of train files is %d != %d, decompress again." %
-                    (file_num, 15))
-                _decompress(os.path.join(default_root, os.path.basename(url)))
+            unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                                     .trainer_endpoints[:])
+            if ParallelEnv().current_endpoint in unique_endpoints:
+                file_num = len(os.listdir(fullname))
+                if file_num != len(ALL_LANGUAGES):
+                    logger.warning(
+                        "Number of train files is %d != %d, decompress again." %
+                        (file_num, len(ALL_LANGUAGES)))
+                    shutil.rmtree(fullname)
+                    _decompress(
+                        os.path.join(default_root, os.path.basename(url)))
         else:
             if not os.path.exists(fullname) or (
                     data_hash and not md5file(fullname) == data_hash):
@@ -79,7 +86,7 @@ def _get_data(self, mode, **kwargs):
 
     def _read(self, filename, split):
         """Reads data."""
-        language = self.config.get("language", "all_languages")
+        language = self.name
         if language == "all_languages":
             languages = ALL_LANGUAGES
         else:
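The `_read` change is what makes the new positional argument work: `self.name` is the config name that `load_dataset("xnli", <language>)` selects, replacing the old `self.config.get("language", ...)` lookup. A small usage sketch; the `"all_languages"` config name comes from the branch visible above, and its acceptance as a second positional argument is assumed from that branch.

```python
from paddlenlp.datasets import load_dataset

# One language: _read sees self.name == "zh" and reads only that
# language's data.
zh_test = load_dataset("xnli", "zh", splits="test")

# Every language: the "all_languages" branch above iterates ALL_LANGUAGES.
all_test = load_dataset("xnli", "all_languages", splits="test")
```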
