Skip to content

Commit b69d56f

Browse files
author
gongenlei
authored
Refactor dataset xnli/cnn_dailymail (#1838)
* refactor: add name error for xnli * add default language * use name not version
1 parent c4fc6ea commit b69d56f

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

paddlenlp/datasets/cnn_dailymail.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,9 @@ def fix_missing_period(line):
188188
def _get_data(self, mode):
189189
""" Check and download Dataset """
190190
dl_paths = {}
191-
version = self.config.get("version", "3.0.0")
191+
version = self.name
192+
if version is None:
193+
version = "3.0.0"
192194
if version not in ["1.0.0", "2.0.0", "3.0.0"]:
193195
raise ValueError("Unsupported version: %s" % version)
194196
dl_paths["version"] = version

paddlenlp/datasets/xnli.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,10 @@
3131
from . import DatasetBuilder
3232

3333
__all__ = ['XNLI']
34-
ALL_LANGUAGES = ("ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw",
35-
"th", "tr", "ur", "vi", "zh")
34+
ALL_LANGUAGES = [
35+
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
36+
"ur", "vi", "zh"
37+
]
3638

3739

3840
class XNLI(DatasetBuilder):
@@ -92,6 +94,12 @@ def _get_data(self, mode, **kwargs):
9294
def _read(self, filename, split):
9395
"""Reads data."""
9496
language = self.name
97+
if language is None:
98+
language = "all_languages"
99+
if language not in ALL_LANGUAGES + ["all_languages"]:
100+
raise ValueError(
101+
f"Name parameter should be specified. Can be one of {ALL_LANGUAGES + ['all_languages']}. "
102+
)
95103
if language == "all_languages":
96104
languages = ALL_LANGUAGES
97105
else:

0 commit comments

Comments
 (0)