@@ -1388,6 +1388,8 @@ def cross_entropy(input,
1388
1388
"should be '-100', but received %s, which is not allowed." %
1389
1389
ignore_index )
1390
1390
1391
+ softmax_switch = use_softmax
1392
+
1391
1393
input_dims = len (list (input .shape ))
1392
1394
label_dims = len (list (label .shape ))
1393
1395
if input_dims - 1 != label_dims and input_dims != label_dims :
@@ -1400,7 +1402,7 @@ def cross_entropy(input,
1400
1402
_ , out = core .ops .softmax_with_cross_entropy (
1401
1403
input , label , 'soft_label' , soft_label , 'ignore_index' ,
1402
1404
ignore_index , 'numeric_stable_mode' , True , 'axis' , axis ,
1403
- 'use_softmax ' , use_softmax )
1405
+ 'softmax_switch ' , softmax_switch )
1404
1406
1405
1407
if weight is not None :
1406
1408
@@ -1482,7 +1484,7 @@ def cross_entropy(input,
1482
1484
'ignore_index' : ignore_index ,
1483
1485
'numeric_stable_mode' : True ,
1484
1486
'axis' : axis ,
1485
- 'use_softmax ' : use_softmax
1487
+ 'softmax_switch ' : softmax_switch
1486
1488
}
1487
1489
helper = LayerHelper ('softmax_with_cross_entropy' , ** locals ())
1488
1490
softmax = helper .create_variable_for_type_inference (dtype = input .dtype )
0 commit comments