
Commit ff17c36

Remove F.softmax_with_cross_entropy in paddlenlp (#484)

* upgrade F.cross_entropy usage
* fix sample code bug
* fix ppl shape error

1 parent: bac370b

7 files changed, +37 -22 lines
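
Across all seven files the pattern is the same: the deprecated F.softmax_with_cross_entropy returned an unreduced per-token loss, so its drop-in replacement is F.cross_entropy with reduction='none', which would otherwise average over the batch. Below is a minimal sketch of the equivalence; the shapes and variable names are illustrative, not taken from the diff.

import paddle
import paddle.nn.functional as F

batch_size, seq_len, vocab_size = 4, 16, 100
logits = paddle.randn([batch_size, seq_len, vocab_size])
# Hard labels carry a trailing singleton class dim, as in the code below.
label = paddle.randint(low=0, high=vocab_size, shape=[batch_size, seq_len, 1])

# Old call (removed by this commit):
#   loss = F.softmax_with_cross_entropy(logits=logits, label=label, soft_label=False)
# New call: reduction='none' keeps the per-token loss tensor, so the existing
# squeeze/mask/mean logic downstream is unchanged.
loss = F.cross_entropy(
    input=logits, label=label, soft_label=False, reduction='none')
print(loss.shape)  # e.g. [4, 16, 1], later squeezed with paddle.squeeze(..., axis=[2])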

examples/machine_translation/seq2seq/seq2seq_attn.py

Lines changed: 2 additions & 2 deletions

@@ -25,8 +25,8 @@ def __init__(self):
         super(CrossEntropyCriterion, self).__init__()

     def forward(self, predict, label, trg_mask):
-        cost = F.softmax_with_cross_entropy(
-            logits=predict, label=label, soft_label=False)
+        cost = F.cross_entropy(
+            input=predict, label=label, soft_label=False, reduction='none')
         cost = paddle.squeeze(cost, axis=[2])
         masked_cost = cost * trg_mask
         batch_mean_cost = paddle.mean(masked_cost, axis=[0])

paddlenlp/metrics/perplexity.py

Lines changed: 4 additions & 3 deletions

@@ -46,9 +46,10 @@ def __init__(self, name='Perplexity', *args, **kwargs):
         self.total_word_num = 0

     def compute(self, pred, label, seq_mask=None):
-        label = paddle.unsqueeze(label, axis=2)
-        ce = F.softmax_with_cross_entropy(
-            logits=pred, label=label, soft_label=False)
+        if label.dim() == 2:
+            label = paddle.unsqueeze(label, axis=2)
+        ce = F.cross_entropy(
+            input=pred, label=label, reduction='none', soft_label=False)
         ce = paddle.squeeze(ce, axis=[2])
         if seq_mask is not None:
             ce = ce * seq_mask
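
The "fix ppl shape error" item refers to the guard above: Perplexity.compute now adds the trailing class dimension only when the label arrives as a 2-D [batch, seq_len] tensor, so callers that already pass [batch, seq_len, 1] are not expanded again. A small standalone illustration of the guarded unsqueeze (tensor names here are illustrative):

import paddle

def normalize_label(label):
    # Mirrors the check added in Perplexity.compute: only 2-D labels gain
    # the trailing dim expected by F.cross_entropy and the later squeeze.
    if label.dim() == 2:
        label = paddle.unsqueeze(label, axis=2)
    return label

label_2d = paddle.randint(low=0, high=100, shape=[4, 16])
label_3d = paddle.randint(low=0, high=100, shape=[4, 16, 1])
print(normalize_label(label_2d).shape)  # [4, 16, 1]
print(normalize_label(label_3d).shape)  # [4, 16, 1], unchanged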

paddlenlp/transformers/bert/modeling.py

Lines changed: 7 additions & 4 deletions

@@ -535,9 +535,12 @@ def __init__(self, vocab_size):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
         with paddle.static.amp.fp16_guard():
-            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = F.cross_entropy(
+                prediction_scores,
+                masked_lm_labels,
+                reduction='none',
+                ignore_index=-1)
             masked_lm_loss = masked_lm_loss / masked_lm_scale
-            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                seq_relationship_score, next_sentence_labels)
+            next_sentence_loss = F.cross_entropy(
+                seq_relationship_score, next_sentence_labels, reduction='none')
         return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)

paddlenlp/transformers/bigbird/modeling.py

Lines changed: 7 additions & 4 deletions

@@ -862,14 +862,17 @@ def forward(self, prediction_scores, seq_relationship_score,
                     masked_lm_scale, masked_lm_weights)
                 print(loss)
         """
-        masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            prediction_scores, masked_lm_labels, ignore_index=self.ignore_index)
+        masked_lm_loss = F.cross_entropy(
+            prediction_scores,
+            masked_lm_labels,
+            ignore_index=self.ignore_index,
+            reduction='none')
         masked_lm_loss = paddle.transpose(masked_lm_loss, [1, 0])
         masked_lm_loss = paddle.sum(masked_lm_loss * masked_lm_weights) / (
             paddle.sum(masked_lm_weights) + 1e-5)
         scale = 1.0
         if not self.use_nsp:
             scale = 0.0
-        next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            seq_relationship_score, next_sentence_labels)
+        next_sentence_loss = F.cross_entropy(
+            seq_relationship_score, next_sentence_labels, reduction='none')
         return masked_lm_loss + paddle.mean(next_sentence_loss) * scale

paddlenlp/transformers/ernie/modeling.py

Lines changed: 8 additions & 4 deletions

@@ -14,6 +14,7 @@

 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F

 from .. import PretrainedModel, register_base_model

@@ -772,9 +773,12 @@ def __init__(self, vocab_size):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
         with paddle.static.amp.fp16_guard():
-            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = F.cross_entropy(
+                prediction_scores,
+                masked_lm_labels,
+                ignore_index=-1,
+                reduction='none')
             masked_lm_loss = masked_lm_loss / masked_lm_scale
-            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                seq_relationship_score, next_sentence_labels)
+            next_sentence_loss = F.cross_entropy(
+                seq_relationship_score, next_sentence_labels, reduction='none')
         return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)

paddlenlp/transformers/ernie_gen/modeling.py

Lines changed: 5 additions & 2 deletions

@@ -607,7 +607,10 @@ def forward(self, *args, **kwargs):
         if len(tgt_labels.shape) == 1:
             tgt_labels = paddle.reshape(tgt_labels, [-1, 1])

-        loss = paddle.nn.functional.cross_entropy(
-            logits_2d, tgt_labels, soft_label=(tgt_labels.shape[-1] != 1))
+        loss = F.cross_entropy(
+            logits_2d,
+            tgt_labels,
+            reduction="none",
+            soft_label=(tgt_labels.shape[-1] != 1))

         return loss, logits_2d, info

paddlenlp/transformers/transformer/modeling.py

Lines changed: 4 additions & 3 deletions

@@ -252,7 +252,7 @@ def forward(self, predict, label):
             label = paddle.randint(
                 low=3,
                 high=vocab_size,
-                shape=[batch_size, seq_len, vocab_size])
+                shape=[batch_size, seq_len, 1])

             criterion(predict, label)
         """

@@ -265,9 +265,10 @@ def forward(self, predict, label):
                 x=label, num_classes=predict.shape[-1]),
             epsilon=self.label_smooth_eps)

-        cost = F.softmax_with_cross_entropy(
-            logits=predict,
+        cost = F.cross_entropy(
+            input=predict,
             label=label,
+            reduction='none',
             soft_label=True if self.label_smooth_eps else False)
         weighted_cost = cost * weights
         sum_cost = paddle.sum(weighted_cost)
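
The docstring fix in this file (sample label shape [batch_size, seq_len, 1] instead of [batch_size, seq_len, vocab_size]) matches what the criterion expects: integer class indices, which it converts to a smoothed one-hot target itself when label_smooth_eps is set and then scores with soft_label=True. A rough sketch of that soft-label path follows, with illustrative shapes and an explicit squeeze added here only so the standalone snippet's dimensions line up:

import paddle
import paddle.nn.functional as F

batch_size, seq_len, vocab_size, eps = 4, 16, 100, 0.1
predict = paddle.randn([batch_size, seq_len, vocab_size])
label = paddle.randint(low=3, high=vocab_size, shape=[batch_size, seq_len, 1])

# Build the smoothed soft target from the hard label, as the criterion does
# when label_smooth_eps is non-zero.
hard = paddle.squeeze(label, axis=[2])  # [batch, seq_len]
soft_target = F.label_smooth(
    label=F.one_hot(x=hard, num_classes=vocab_size), epsilon=eps)

# soft_label=True scores against the dense distribution; reduction='none'
# again preserves the per-token cost that is later weighted and summed.
cost = F.cross_entropy(
    input=predict, label=soft_target, soft_label=True, reduction='none')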
