Skip to content

Commit 855c9e3

Browse files
committed
clean softmax_op code
1 parent 24d51de commit 855c9e3

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

paddle/fluid/operators/softmax_op.h

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
3131
// allocate memory on device.
3232
Out->mutable_data<T>(context.GetPlace());
3333

34-
auto dims = X->dims();
35-
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
36-
framework::LoDTensor flattened_x;
37-
framework::LoDTensor flattened_out;
38-
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
39-
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
34+
int rank = X->dims().size();
35+
Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
36+
Tensor Out_2d =
37+
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
4038

4139
math::SoftmaxFunctor<DeviceContext, T>()(
42-
context.template device_context<DeviceContext>(), &flattened_x,
43-
&flattened_out);
40+
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
4441
}
4542
};
4643

@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
5552
// allocate memory on device.
5653
dX->mutable_data<T>(context.GetPlace());
5754

58-
auto dims = Out->dims();
59-
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
60-
framework::LoDTensor flattened_out;
61-
framework::LoDTensor flattened_d_out;
62-
framework::LoDTensor flattened_d_x;
63-
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
64-
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
65-
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
55+
int rank = Out->dims().size();
56+
Tensor Out_2d =
57+
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
58+
Tensor dOut_2d =
59+
rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
60+
Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
6661

6762
math::SoftmaxGradFunctor<DeviceContext, T>()(
68-
context.template device_context<DeviceContext>(), &flattened_out,
69-
&flattened_d_out, &flattened_d_x);
63+
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
64+
&dX_2d);
7065
}
7166
};
7267

0 commit comments

Comments
 (0)