Skip to content

Commit 97cd708

Browse files
authored
cherry-pick:add softmax_switch for softmax_with_cross_entropy_op (#32105)
* cherry-pick:add softmax_switch for softmax_with_cross_entropy_op, test=develop * add softmax_switch for softmax_with_cross_entropy_op, test=develop * delete using EigenMatrix in softmax_with_cross_entropy_op.h, test=develop * add REGISTER_OP_VERSION for softmax_switch attr of softmax_with_cross_entropy_op, test=develop * cherry-pick:add softmax_switch for softmax_with_cross_entropy_op,test=develop * change softmax_switch to use_softmax, test=develop * fix code format for softmax_with_cross_entropy_op.cc, test=develop
1 parent 1b3cd0f commit 97cd708

File tree

5 files changed

+705
-14
lines changed

5 files changed

+705
-14
lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <string>
1818
#include <unordered_map>
1919
#include <vector>
20+
#include "paddle/fluid/framework/op_version_registry.h"
2021

2122
namespace paddle {
2223
namespace operators {
@@ -53,6 +54,10 @@ class SoftmaxWithCrossEntropyOpMaker
5354
"(bool, default: false), A flag to indicate whether to interpretant "
5455
"the given labels as soft labels.")
5556
.SetDefault(false);
57+
AddAttr<bool>(
58+
"use_softmax",
59+
"(bool, default: true), A flag to indicate whether to do softmax ")
60+
.SetDefault(true);
5661
AddAttr<bool>(
5762
"numeric_stable_mode",
5863
"(bool, default: true), A flag to indicate whether to use more "
@@ -312,3 +317,10 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
312317
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
313318
ops::SoftmaxWithCrossEntropyGradKernel<float>,
314319
ops::SoftmaxWithCrossEntropyGradKernel<double>);
320+
321+
REGISTER_OP_VERSION(softmax_with_cross_entropy)
322+
.AddCheckpoint(
323+
R"ROC(
324+
Add a new attribute [use_softmax] )ROC",
325+
paddle::framework::compatible::OpVersionDesc().NewAttr(
326+
"use_softmax", "A flag to indicate whether to do softmax", true));

0 commit comments

Comments
 (0)