Commit 062e41e

Fix community provided datasets and fix cross_entropy for elmo and dgu (#576)
* Fix dataset doc
* fix dataset doc
* Change load_dataset design for loading local dataset
* Update softmax_with_cross_entropy to cross_entropy
* Fix community provided datasets and fix cross_entropy for elmo and dgu
1 parent 6b1fa66 commit 062e41e

4 files changed, +35 -26 lines changed


examples/dialogue/dgu/main.py

Lines changed: 1 addition & 3 deletions
@@ -65,20 +65,18 @@ def get_loss_fn(self):
         if self.task_name in [
                 'udc', 'atis_slot', 'atis_intent', 'mrda', 'swda'
         ]:
-            return F.softmax_with_cross_entropy
+            return F.cross_entropy
         elif self.task_name == 'dstc2':
             return nn.BCEWithLogitsLoss(reduction='sum')
 
     def forward(self, logits, labels):
         if self.task_name in ['udc', 'atis_intent', 'mrda', 'swda']:
             loss = self.loss_fn(logits, labels)
-            loss = paddle.mean(loss)
         elif self.task_name == 'dstc2':
             loss = self.loss_fn(logits, paddle.cast(labels, dtype=logits.dtype))
         elif self.task_name == 'atis_slot':
             labels = paddle.unsqueeze(labels, axis=-1)
             loss = self.loss_fn(logits, labels)
-            loss = paddle.mean(loss)
         return loss
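Note: the explicit paddle.mean calls can be dropped because paddle.nn.functional.cross_entropy applies softmax internally and reduces with reduction='mean' by default, while the replaced softmax_with_cross_entropy returned per-example losses. A minimal sketch of the behavior this change relies on, with assumed toy shapes:

import paddle
import paddle.nn.functional as F

# Assumed toy shapes: [batch_size, num_classes] logits and int64 class indices.
logits = paddle.randn([8, 5])
labels = paddle.randint(0, 5, [8], dtype='int64')

# reduction='mean' is the default, so the result is already the averaged loss
# and no extra paddle.mean is needed.
loss = F.cross_entropy(logits, labels)
print(loss)  # scalar mean cross-entropy over the batch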

examples/language_model/elmo/elmo.py

Lines changed: 3 additions & 3 deletions
@@ -302,10 +302,10 @@ def forward(self, x, y):
         bw_label = paddle.unsqueeze(bw_label, axis=2)
 
         # [batch_size, seq_len, 1]
-        fw_loss = F.softmax_with_cross_entropy(logits=fw_logits, label=fw_label)
-        bw_loss = F.softmax_with_cross_entropy(logits=bw_logits, label=bw_label)
+        fw_loss = F.cross_entropy(input=fw_logits, label=fw_label)
+        bw_loss = F.cross_entropy(input=bw_logits, label=bw_label)
 
-        avg_loss = 0.5 * (paddle.mean(fw_loss) + paddle.mean(bw_loss))
+        avg_loss = 0.5 * (fw_loss + bw_loss)
         return avg_loss
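Note: as in the dgu change, F.cross_entropy averages over all label positions by default, so the forward and backward losses come back as scalars and only need averaging across the two directions; the labels unsqueezed to [batch_size, seq_len, 1] are accepted as-is. A hedged sketch with assumed toy sizes:

import paddle
import paddle.nn.functional as F

batch_size, seq_len, vocab_size = 2, 4, 10  # assumed toy sizes
fw_logits = paddle.randn([batch_size, seq_len, vocab_size])
bw_logits = paddle.randn([batch_size, seq_len, vocab_size])
fw_label = paddle.randint(0, vocab_size, [batch_size, seq_len, 1], dtype='int64')
bw_label = paddle.randint(0, vocab_size, [batch_size, seq_len, 1], dtype='int64')

# Each call already returns the mean over batch and time steps.
fw_loss = F.cross_entropy(input=fw_logits, label=fw_label)
bw_loss = F.cross_entropy(input=bw_logits, label=bw_label)
avg_loss = 0.5 * (fw_loss + bw_loss)  # average of the two scalar losses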

paddlenlp/datasets/bq_corpus.py

Lines changed: 16 additions & 14 deletions
@@ -12,9 +12,11 @@
 
 class BQCorpus(DatasetBuilder):
     """
-    BQCorpus: the largest dataset available for for the banking and finance sector
+    BQCorpus: A Large-scale Domain-specific Chinese Corpus For Sentence
+    Semantic Equivalence Identification. More information please refer
+    to `https://www.aclweb.org/anthology/D18-1536.pdf`
 
-    by frozenfish123@Wuhan University
+    Contributed by frozenfish123@Wuhan University
 
     """
     lazy = False
@@ -23,13 +25,13 @@ class BQCorpus(DatasetBuilder):
     META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
     SPLITS = {
         'train': META_INFO(
-            os.path.join('BQCorpus', 'train.tsv'),
+            os.path.join('bq_corpus', 'bq_corpus', 'train.tsv'),
             'd37683e9ee778ee2f4326033b654adb9'),
         'dev': META_INFO(
-            os.path.join('BQCorpus', 'dev.tsv'),
+            os.path.join('bq_corpus', 'bq_corpus', 'dev.tsv'),
             '8a71f2a69453646921e9ee1aa457d1e4'),
         'test': META_INFO(
-            os.path.join('BQCorpus', 'test.tsv'),
+            os.path.join('bq_corpus', 'bq_corpus', 'test.tsv'),
             'c797995baa248b144ceaa4018b191e52'),
     }
 
@@ -47,18 +49,18 @@ def _get_data(self, mode, **kwargs):
     def _read(self, filename):
         """Reads data."""
         with open(filename, 'r', encoding='utf-8') as f:
-            head = None
             for line in f:
                 data = line.strip().split("\t")
-                if not head:
-                    head = data
-                else:
+                if len(data) == 3:
                     sentence1, sentence2, label = data
-                    yield {
-                        "sentence1": sentence1,
-                        "sentence2": sentence2,
-                        "label": label
-                    }
+                elif len(data) == 2:
+                    sentence1, sentence2 = data
+                    label = ''
+                yield {
+                    "sentence1": sentence1,
+                    "sentence2": sentence2,
+                    "label": label
+                }
 
     def get_labels(self):
         """

paddlenlp/datasets/paws-x.py

Lines changed: 15 additions & 6 deletions
@@ -21,9 +21,10 @@
 from paddlenlp.utils.env import DATA_HOME
 from . import DatasetBuilder
 
-__all__ = ['PAWS']
+__all__ = ['PAWSX']
 
-class PAWS(DatasetBuilder):
+
+class PAWSX(DatasetBuilder):
     """
     PAWS-X: A Cross-lingual Adversarial Dataset for Paraphrase Identification
     More information please refer to `https://arxiv.org/abs/1908.11828`
@@ -60,11 +61,19 @@ def _read(self, filename):
             for line in f:
                 data = line.strip().split("\t")
                 if len(data) == 3:
-                    sentence1, sentence2, label = data
-                    yield {"sentence1": sentence1, "sentence2": sentence2, "label": label}
+                    sentence1, sentence2, label = data
+                    yield {
+                        "sentence1": sentence1,
+                        "sentence2": sentence2,
+                        "label": label
+                    }
                 elif len(data) == 2:
-                    sentence1, sentence2 = data
-                    yield {"sentence1": sentence1, "sentence2": sentence2, "label":''}
+                    sentence1, sentence2 = data
+                    yield {
+                        "sentence1": sentence1,
+                        "sentence2": sentence2,
+                        "label": ''
+                    }
                 else:
                     continue
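Note: the paws-x reader now keys off the column count, like the bq_corpus reader above: a three-column row keeps its label, a two-column row (e.g. an unlabeled test split) gets an empty label, and any other row is skipped. A self-contained illustration of that parsing rule on in-memory rows:

rows = [
    "How are you?\tHow do you do?\t1",  # labeled pair
    "How are you?\tWhere are you?",     # unlabeled pair, test-split style
    "malformed line",                   # skipped
]

examples = []
for line in rows:
    data = line.strip().split("\t")
    if len(data) == 3:
        sentence1, sentence2, label = data
    elif len(data) == 2:
        sentence1, sentence2 = data
        label = ''
    else:
        continue
    examples.append({
        "sentence1": sentence1,
        "sentence2": sentence2,
        "label": label
    })

print(examples)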
