Skip to content

Commit 3a25588

Browse files
authored
fix use_softmax=False does not work, test=develop (#32035)
1 parent 1f8834a commit 3a25588

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,8 @@ def cross_entropy(input,
13881388
"should be '-100', but received %s, which is not allowed." %
13891389
ignore_index)
13901390

1391+
softmax_switch = use_softmax
1392+
13911393
input_dims = len(list(input.shape))
13921394
label_dims = len(list(label.shape))
13931395
if input_dims - 1 != label_dims and input_dims != label_dims:
@@ -1400,7 +1402,7 @@ def cross_entropy(input,
14001402
_, out = core.ops.softmax_with_cross_entropy(
14011403
input, label, 'soft_label', soft_label, 'ignore_index',
14021404
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
1403-
'use_softmax', use_softmax)
1405+
'softmax_switch', softmax_switch)
14041406

14051407
if weight is not None:
14061408

@@ -1482,7 +1484,7 @@ def cross_entropy(input,
14821484
'ignore_index': ignore_index,
14831485
'numeric_stable_mode': True,
14841486
'axis': axis,
1485-
'use_softmax': use_softmax
1487+
'softmax_switch': softmax_switch
14861488
}
14871489
helper = LayerHelper('softmax_with_cross_entropy', **locals())
14881490
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)

0 commit comments

Comments
 (0)