
Commit 9a589de

cherry-pick: change softmax_with_cross_entropy_op's parameter name from softmax_switch to use_softmax (#32750)

* change parameter name from softmax_switch to use_softmax, test=develop
* cherry-pick: change parameter name from softmax_switch to use_softmax, test=develop
1 parent 0bb079c commit 9a589de

File tree: 5 files changed, +56 −59 lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.cc

Lines changed: 3 additions & 4 deletions
@@ -55,7 +55,7 @@ class SoftmaxWithCrossEntropyOpMaker
         "the given labels as soft labels.")
         .SetDefault(false);
     AddAttr<bool>(
-        "softmax_switch",
+        "use_softmax",
         "(bool, default: true), A flag to indicate whether to do softmax ")
         .SetDefault(true);
     AddAttr<bool>(
@@ -320,7 +320,6 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
 REGISTER_OP_VERSION(softmax_with_cross_entropy)
     .AddCheckpoint(
         R"ROC(
-              Add a new attribute [softmax_switch] )ROC",
+              Add a new attribute [use_softmax] )ROC",
         paddle::framework::compatible::OpVersionDesc().NewAttr(
-            "softmax_switch", "A flag to indicate whether to do softmax",
-            true));
+            "use_softmax", "A flag to indicate whether to do softmax", true));

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 4 additions & 4 deletions
@@ -772,10 +772,10 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
         platform::is_gpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                       "CUDA kernel only runs on GPU device."));
-    const bool softmax_switch = context.Attr<bool>("softmax_switch");
+    const bool use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
-    if (!softmax_switch) {
+    if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
       const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
@@ -925,10 +925,10 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     int block = 512;
     auto stream = context.cuda_device_context().stream();
     auto ignore_index = context.Attr<int>("ignore_index");
-    auto softmax_switch = context.Attr<bool>("softmax_switch");
+    auto use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
-    if (!softmax_switch) {
+    if (!use_softmax) {
       if (context.Attr<bool>("soft_label")) {
         int grid = (n * d + block - 1) / block;
         const T* label_data = labels->data<T>();
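
One detail in the grad hunk worth spelling out: the launch configuration grid = (n * d + block - 1) / block is ceiling division, choosing the smallest number of 512-thread blocks that covers all n * d elements. A minimal standalone check (hypothetical sizes, not from this commit):

#include <cstdio>

// Ceiling division used above to size the CUDA grid: the smallest number
// of blocks of `block` threads that covers `total` elements.
int grid_for(int total, int block) { return (total + block - 1) / block; }

int main() {
  const int n = 128, d = 1000, block = 512;
  printf("grid = %d\n", grid_for(n * d, block));  // prints "grid = 250"
  return 0;
}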

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 8 additions & 8 deletions
@@ -31,10 +31,10 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         platform::is_cpu_place(context.GetPlace()), true,
         platform::errors::Unimplemented("This kernel only runs on CPU."));
-    const bool softmax_switch = context.Attr<bool>("softmax_switch");
+    const bool use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
-    if (!softmax_switch) {
+    if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
       const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
@@ -113,9 +113,9 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("Logits"));

     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    const bool softmax_switch = context.Attr<bool>("softmax_switch");
+    const bool use_softmax = context.Attr<bool>("use_softmax");

-    if (logit_grad != softmax || !softmax_switch) {
+    if (logit_grad != softmax || !use_softmax) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
     }
@@ -138,8 +138,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
     auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
     auto& place = *context.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
-    if (!softmax_switch) {
-      // softmax_switch step1
+    if (!use_softmax) {
+      // use_softmax step1
       if (soft_label) {
         auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
         logit_grad_mat.device(place) =
@@ -148,7 +148,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
             out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
             logit_grad_mat;
       }
-      // softmax_switch step2
+      // use_softmax step2
       else {
         const int64_t* label_data = labels->data<int64_t>();
         T* logit_grad_data = logit_grad->data<T>();
@@ -181,7 +181,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
       return;
     }

-    // for softmax_switch=False, continue
+    // for use_softmax=False, continue

     if (soft_label) {
       // when soft_label = True, ignore_index is not supported
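
The two branches of the CPU grad kernel compute different derivatives because the forward pass differs. With use_softmax=true the usual softmax-cross-entropy gradient applies; with use_softmax=false the gradient is taken directly with respect to the input probabilities. A standalone sketch for one row with a hard label (illustrative C++, not the kernel's Eigen code):

#include <cmath>

// use_softmax == true: loss = -log(softmax(logits)[label]), so
// d(loss)/d(logit_i) = softmax_i - 1{i == label}, scaled by the upstream grad.
void grad_wrt_logits(const double* softmax, int n, int label, double out_grad,
                     double* logit_grad) {
  for (int i = 0; i < n; ++i)
    logit_grad[i] = out_grad * (softmax[i] - (i == label ? 1.0 : 0.0));
}

// use_softmax == false: loss = -log(probs[label]), so
// d(loss)/d(prob_i) = -1{i == label} / prob_i, scaled by the upstream grad.
void grad_wrt_probs(const double* probs, int n, int label, double out_grad,
                    double* prob_grad) {
  for (int i = 0; i < n; ++i)
    prob_grad[i] = (i == label) ? -out_grad / probs[i] : 0.0;
}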
