
Commit 6c7b64c

Author: Yibing Liu

Support softmax return in softmax_with_cross_entropy (#14367)

* Support softmax return in softmax_with_cross_entropy
* Add test for return_softmax=False

test=develop

1 parent df826de · commit 6c7b64c
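
In short, the change lets callers retrieve the softmax that the op already computes internally, instead of only the loss. A minimal before/after sketch (illustrative only: `logits` and `label` stand for existing fluid variables):

    # Unchanged default behaviour: only the loss is returned.
    loss = fluid.layers.softmax_with_cross_entropy(logits, label)

    # New in this commit: also return the softmax computed inside the op.
    loss, softmax = fluid.layers.softmax_with_cross_entropy(
        logits, label, return_softmax=True)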

File tree

3 files changed: +18 -3 lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
-paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False))
+paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False))
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))

python/paddle/fluid/layers/nn.py

Lines changed: 13 additions & 2 deletions

@@ -4742,7 +4742,8 @@ def softmax_with_cross_entropy(logits,
                                label,
                                soft_label=False,
                                ignore_index=-100,
-                               numeric_stable_mode=False):
+                               numeric_stable_mode=False,
+                               return_softmax=False):
     """
     **Softmax With Cross Entropy Operator.**

@@ -4806,9 +4807,15 @@ def softmax_with_cross_entropy(logits,
                                      the algorithm is always numerically stable.
                                      Note that the speed may be slower when use
                                      stable algorithm. Default: False
+        return_softmax (bool): A flag indicating whether to return the softmax
+                               along with the cross entropy loss. Default: False

     Returns:
-        Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
+        Variable or Tuple of two Variables: Return the cross entropy loss if
+                                            `return_softmax` is False, otherwise the tuple
+                                            (loss, softmax), where the cross entropy loss is
+                                            a 2-D tensor with shape [N x 1], and softmax is a
+                                            2-D tensor with shape [N x K].

     Examples:
         .. code-block:: python
@@ -4833,6 +4840,10 @@ def softmax_with_cross_entropy(logits,
             'ignore_index': ignore_index,
             'numeric_stable_mode': numeric_stable_mode
         })
+
+    if return_softmax:
+        return loss, softmax
+
     return loss
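
For reference, a self-contained usage sketch of the updated layer, mirroring the new unit test below (assumes a PaddlePaddle build that includes this commit; all variable names are illustrative):

    import paddle.fluid as fluid

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # Batch of logits over K = 10 classes, plus integer class labels.
        logits = fluid.layers.data(name='logits', shape=[10], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')

        # loss has shape [N x 1]; softmax has shape [N x K], as documented above.
        loss, softmax = fluid.layers.softmax_with_cross_entropy(
            logits, label, return_softmax=True)

Returning the softmax from the same op means a caller that needs the predicted probabilities does not have to append a second softmax over the same logits.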

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 4 additions & 0 deletions

@@ -369,6 +369,10 @@ def test_softmax_with_cross_entropy(self):
         with program_guard(program):
             x = layers.data(name='x', shape=[16], dtype='float32')
             y = layers.data(name='label', shape=[1], dtype='int64')
+            loss, softmax = layers.softmax_with_cross_entropy(
+                x, y, return_softmax=True)
+            self.assertIsNotNone(loss)
+            self.assertIsNotNone(softmax)
             loss = layers.softmax_with_cross_entropy(x, y)
             self.assertIsNotNone(loss)
             print(str(program))
