
Commit 332112a

Change load_dataset design for loading local datasets (#565)
* Fix dataset doc
* fix dataset doc
* Change load_dataset design for loading local dataset
* Update softmax_with_cross_entropy to cross_entropy
1 parent 17fe183 commit 332112a

5 files changed (+44, -43 lines)


docs/data_prepare/dataset_load.rst

Lines changed: 6 additions & 2 deletions
@@ -55,9 +55,13 @@

 .. note::

-    For some datasets, different splits are read in different ways. In such cases, the split information needs to be passed to the :attr:`data_files` argument as a dictionary. Take the **COLA** dataset as an example:
+    For some datasets, different splits are read in different ways. In such cases, the split information needs to be passed via the :attr:`splits` argument, in **one-to-one correspondence** with :attr:`data_files`.
+
+    Here :attr:`splits` no longer selects a built-in dataset; it specifies the format in which each local data file is read.
+
+    The **COLA** dataset is used as an example below:

 .. code-block::

     >>> from paddlenlp.datasets import load_dataset
-    >>> train_ds, test_ds = load_dataset("glue", "cola", data_files={"train": "my_train_file.csv", "test": "my_test_file.csv"})
+    >>> train_ds, test_ds = load_dataset("glue", "cola", splits=["train", "test"], data_files=["my_train_file.csv", "my_test_file.csv"])
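
To make the new contract concrete, here is a hedged sketch of the call patterns the reworked `read_datasets` (see the paddlenlp/datasets/dataset.py hunks below) supports; the local file names are hypothetical:

    >>> from paddlenlp.datasets import load_dataset
    >>> # A single local file with no `splits`: the file is read under
    >>> # the default split name 'train'.
    >>> train_ds = load_dataset("glue", "cola", data_files="my_train_file.csv")
    >>> # Several local files with matching split names: `splits` and
    >>> # `data_files` must be the same length, one split per file.
    >>> train_ds, dev_ds = load_dataset("glue", "cola", splits=["train", "dev"], data_files=["my_train_file.csv", "my_dev_file.csv"])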

examples/language_model/rnnlm/model.py

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ def __init__(self):

     def forward(self, y, label):
         label = paddle.unsqueeze(label, axis=2)
-        loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=y, label=label, soft_label=False)
+        loss = paddle.nn.functional.cross_entropy(
+            input=y, label=label, reduction='none')
         loss = paddle.squeeze(loss, axis=[2])
         loss = paddle.mean(loss, axis=[0])
         loss = paddle.sum(loss)
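
The swap works because `paddle.nn.functional.cross_entropy` applies softmax to its input internally, just as the removed `softmax_with_cross_entropy` did, and `reduction='none'` preserves the per-position loss so the squeeze/mean/sum pipeline that follows is unchanged. A minimal equivalence check, with hypothetical tensor shapes:

    import paddle
    import paddle.nn.functional as F

    # Hypothetical shapes: batch 4, sequence length 5, vocabulary 10.
    logits = paddle.randn([4, 5, 10])
    labels = paddle.randint(0, 10, [4, 5, 1])

    # Both calls softmax the logits and, with hard labels and no
    # reduction, return the unreduced per-position loss, so they
    # should agree elementwise.
    old = F.softmax_with_cross_entropy(
        logits=logits, label=labels, soft_label=False)
    new = F.cross_entropy(input=logits, label=labels, reduction='none')
    print(paddle.allclose(old, new))  # expected: True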

examples/machine_reading_comprehension/DuReader-robust/run_du.py

Lines changed: 4 additions & 7 deletions
@@ -101,13 +101,10 @@ def forward(self, y, label):
         start_position, end_position = label
         start_position = paddle.unsqueeze(start_position, axis=-1)
         end_position = paddle.unsqueeze(end_position, axis=-1)
-        start_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=start_logits, label=start_position, soft_label=False)
-        start_loss = paddle.mean(start_loss)
-        end_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=end_logits, label=end_position, soft_label=False)
-        end_loss = paddle.mean(end_loss)
-
+        start_loss = paddle.nn.functional.cross_entropy(
+            input=start_logits, label=start_position)
+        end_loss = paddle.nn.functional.cross_entropy(
+            input=end_logits, label=end_position)
         loss = (start_loss + end_loss) / 2
         return loss
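
Unlike the rnnlm change above, the explicit `paddle.mean` calls disappear here rather than being kept: `paddle.nn.functional.cross_entropy` defaults to `reduction='mean'`, so it already returns the averaged scalar loss that the removed `paddle.mean` used to compute.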

examples/machine_reading_comprehension/SQuAD/run_squad.py

Lines changed: 4 additions & 7 deletions
@@ -204,13 +204,10 @@ def forward(self, y, label):
         start_position, end_position = label
         start_position = paddle.unsqueeze(start_position, axis=-1)
         end_position = paddle.unsqueeze(end_position, axis=-1)
-        start_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=start_logits, label=start_position, soft_label=False)
-        start_loss = paddle.mean(start_loss)
-        end_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=end_logits, label=end_position, soft_label=False)
-        end_loss = paddle.mean(end_loss)
-
+        start_loss = paddle.nn.functional.cross_entropy(
+            input=start_logits, label=start_position)
+        end_loss = paddle.nn.functional.cross_entropy(
+            input=end_logits, label=end_position)
         loss = (start_loss + end_loss) / 2
         return loss
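
This is the same substitution as in run_du.py above, again relying on `cross_entropy`'s default mean reduction in place of the removed explicit `paddle.mean` calls.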

paddlenlp/datasets/dataset.py

Lines changed: 28 additions & 25 deletions
@@ -472,30 +472,6 @@ def __init__(self, lazy=None, name=None, **config):
     def read_datasets(self, splits=None, data_files=None):
         datasets = []
         assert splits or data_files, "`data_files` and `splits` can not both be None."
-        assert splits is None or data_files is None, "Only one of `data_files` and `splits` can be set."
-
-        if data_files:
-            assert isinstance(data_files, str) or isinstance(
-                data_files, dict
-            ) or isinstance(data_files, tuple) or isinstance(
-                data_files, list
-            ), "`data_files` should be a string or tuple or list or a dictionary whose key is split name and value is the path of data file."
-            if isinstance(data_files, str):
-                split = 'train'
-                datasets.append(self.read(filename=data_files, split=split))
-            elif isinstance(data_files, tuple) or isinstance(data_files, list):
-                split = 'train'
-                datasets += [
-                    self.read(
-                        filename=filename, split=split)
-                    for filename in data_files
-                ]
-            else:
-                datasets += [
-                    self.read(
-                        filename=filename, split=split)
-                    for split, filename in data_files.items()
-                ]

         def remove_if_exit(filepath):
             if isinstance(filepath, (list, tuple)):
@@ -510,7 +486,7 @@ def remove_if_exit(filepath):
             except OSError:
                 pass

-        if splits:
+        if splits and data_files is None:
             assert isinstance(splits, str) or (
                 isinstance(splits, list) and isinstance(splits[0], str)
             ) or (
@@ -551,6 +527,33 @@ def remove_if_exit(filepath):
                 time.sleep(1)
             datasets.append(self.read(filename=filename, split=split))

+        if data_files:
+            assert isinstance(data_files, str) or isinstance(
+                data_files, tuple) or isinstance(
+                    data_files, list
+                ), "`data_files` should be a string or tuple or list of strings."
+
+            if isinstance(data_files, str):
+                data_files = [data_files]
+            default_split = 'train'
+            if splits:
+                if isinstance(splits, str):
+                    splits = [splits]
+                assert len(splits) == len(
+                    data_files
+                ), "Number of `splits` and number of `data_files` should be the same if you want to specify the split of local data file."
+                datasets += [
+                    self.read(
+                        filename=data_files[i], split=splits[i])
+                    for i in range(len(data_files))
+                ]
+            else:
+                datasets += [
+                    self.read(
+                        filename=data_files[i], split=default_split)
+                    for i in range(len(data_files))
+                ]
+
         return datasets if len(datasets) > 1 else datasets[0]

     def read(self, filename, split='train'):
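
Taken together, these hunks change the contract of `read_datasets`: the old code treated `splits` and `data_files` as mutually exclusive (the removed assert) and encoded split names as dictionary keys of `data_files`; the new code accepts both arguments together, drops dict support in favor of plain strings, tuples, and lists, and pairs `splits[i]` with `data_files[i]` as the format for reading that file, falling back to the 'train' split when no `splits` are given.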
