Skip to content

Commit b314a69

Browse files
committed
make softmax support tensors of any rank
1 parent b1af7e5 commit b314a69

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

paddle/fluid/operators/softmax_op.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
3737
PADDLE_ENFORCE(ctx->HasOutput("Out"),
3838
"Output(Out) of SoftmaxOp should not be null.");
3939

40-
auto x_dims = ctx->GetInputDim("X");
41-
PADDLE_ENFORCE(x_dims.size() == 2UL,
42-
"The input of softmax op must be a matrix.");
43-
ctx->SetOutputDim("Out", x_dims);
40+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
4441
ctx->ShareLoD("X", /*->*/ "Out");
4542
}
4643

@@ -81,8 +78,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
8178
public:
8279
void Make() override {
8380
AddInput("X",
84-
"The input tensor of softmax. "
85-
"2-D with shape [batch_size, input_feature_dimensions].");
81+
"The input tensor of softmax, "
82+
"whose last dimension is the input_feature_dimensions.");
8683
AddOutput("Out", "The normalized values with the same shape as X.")
8784
.Reuse("X");
8885
AddAttr<bool>(
@@ -105,20 +102,23 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
105102
AddComment(R"DOC(
106103
Softmax Operator.
107104
108-
The input of the softmax operator is a 2-D tensor with shape N x K (N is the
109-
batch_size, K is the dimension of input feature). The output tensor has the
110-
same shape as the input tensor.
105+
The input of the softmax operator is a tensor of any rank. The output tensor
106+
has the same shape as the input.
111107
112-
For each row of the input tensor, the softmax operator squashes the
113-
K-dimensional vector of arbitrary real values to a K-dimensional vector of real
114-
values in the range [0, 1] that add up to 1.
108+
The input tensor will first be logically flattened to a 2-D matrix. The matrix's
109+
second dimension (row length) is the same as the last dimension of the input
110+
tensor, and the first dimension (column length) is the product of all other
111+
dimensions of the input tensor. For each row of the matrix, the softmax operator
112+
squashes the K-dimensional (K is the width of the matrix, which is also the size
113+
of the input tensor's last dimension) vector of arbitrary real values to a
114+
K-dimensional vector of real values in the range [0, 1] that add up to 1.
115115
It computes the exponential of the given dimension and the sum of exponential
116116
values of all the other dimensions in the K-dimensional vector input.
117117
Then the ratio of the exponential of the given dimension and the sum of
118118
exponential values of all the other dimensions is the output of the softmax
119119
operator.
120120
121-
For each row $i$ and each column $j$ in Input(X), we have:
121+
For each row $i$ and each column $j$ in the matrix, we have:
122122
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
123123
124124
)DOC");

paddle/fluid/operators/softmax_op.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,16 @@ class SoftmaxKernel : public framework::OpKernel<T> {
3131
// allocate memory on device.
3232
Out->mutable_data<T>(context.GetPlace());
3333

34+
auto dims = X->dims();
35+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
36+
framework::LoDTensor flattened_x;
37+
framework::LoDTensor flattened_out;
38+
flattened_x.ShareDataWith(*X);
39+
flattened_out.ShareDataWith(*Out);
40+
3441
math::SoftmaxFunctor<DeviceContext, T>()(
35-
context.template device_context<DeviceContext>(), X, Out);
42+
context.template device_context<DeviceContext>(), &flattened_x,
43+
&flattened_out);
3644
}
3745
};
3846

@@ -47,8 +55,18 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
4755
// allocate memory on device.
4856
dX->mutable_data<T>(context.GetPlace());
4957

58+
auto dims = Out->dims();
59+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
60+
framework::LoDTensor flattened_out;
61+
framework::LoDTensor flattened_d_out;
62+
framework::LoDTensor flattened_d_x;
63+
flattened_out.ShareDataWith(*Out);
64+
flattened_d_out.ShareDataWith(*dOut);
65+
flattened_d_x.ShareDataWith(*dX);
66+
5067
math::SoftmaxGradFunctor<DeviceContext, T>()(
51-
context.template device_context<DeviceContext>(), Out, dOut, dX);
68+
context.template device_context<DeviceContext>(), &flattened_out,
69+
&flattened_d_out, &flattened_d_x);
5270
}
5371
};
5472

0 commit comments

Comments
 (0)