Skip to content

Commit fa1e0ba

Browse files
authored
update example cross entropy API usage (#572)
1 parent 332112a commit fa1e0ba

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

examples/text_generation/couplet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(self):
2323
super(CrossEntropyCriterion, self).__init__()
2424

2525
def forward(self, predict, label, trg_mask):
26-
cost = F.softmax_with_cross_entropy(
27-
logits=predict, label=label, soft_label=False)
26+
cost = F.cross_entropy(
27+
input=predict, label=label, reduction='none', soft_label=False)
2828
cost = paddle.squeeze(cost, axis=[2])
2929
masked_cost = cost * trg_mask
3030
batch_mean_cost = paddle.mean(masked_cost, axis=[0])

examples/text_generation/couplet/train.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
from args import parse_args
1616

17-
import numpy as np
1817
import paddle
19-
import paddle.nn as nn
20-
import paddle.nn.functional as F
2118
from paddlenlp.metrics import Perplexity
2219

2320
from data import create_train_loader

examples/text_generation/vae-seq2seq/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def forward(self, kl_loss, dec_output, trg_mask, label):
4040
self.update_kl_weight()
4141
self.kl_loss = kl_loss
4242

43-
rec_loss = F.softmax_with_cross_entropy(
44-
logits=dec_output, label=label, soft_label=False)
43+
rec_loss = F.cross_entropy(
44+
input=dec_output, label=label, reduction='none', soft_label=False)
4545

4646
rec_loss = paddle.squeeze(rec_loss, axis=[2])
4747
rec_loss = rec_loss * trg_mask
@@ -117,7 +117,9 @@ def __init__(self, ppl, nll, log_freq=200, verbose=2):
117117

118118
def on_train_begin(self, logs=None):
119119
super(TrainCallback, self).on_train_begin(logs)
120-
self.train_metrics = ["loss", "ppl", "nll", "kl weight", "kl loss", "rec loss"]
120+
self.train_metrics = [
121+
"loss", "ppl", "nll", "kl weight", "kl loss", "rec loss"
122+
]
121123

122124
def on_epoch_begin(self, epoch=None, logs=None):
123125
super(TrainCallback, self).on_epoch_begin(epoch, logs)

0 commit comments

Comments (0)