Skip to content

Commit b656d97

Browse files
authored
Merge pull request #12485 from JiayiFeng/dev_ops_tensor_support
Make lookup_table_op and softmax_op supporting high rank tensor
2 parents 3cea440 + 23aebf0 commit b656d97

File tree

10 files changed

+242
-80
lines changed

10 files changed

+242
-80
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
3232

3333
auto table_dims = ctx->GetInputDim("W");
3434
auto ids_dims = ctx->GetInputDim("Ids");
35+
int ids_rank = ids_dims.size();
3536

36-
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
37-
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
37+
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
38+
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
39+
"The last dimension of the 'Ids' tensor must be 1.");
3840

39-
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
41+
auto output_dims =
42+
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
43+
output_dims.push_back(table_dims[1]);
44+
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
4045

4146
if (ctx->GetOutputsVarType("Out")[0] ==
4247
framework::proto::VarType::LOD_TENSOR) {
@@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
6166
AddInput("Ids",
6267
"An input with type int32 or int64 "
6368
"contains the ids to be looked up in W. "
64-
"Ids must be a column vector with rank = 2. "
65-
"The 2nd dimension size must be 1.");
69+
"The last dimension size must be 1.");
6670
AddOutput("Out", "The lookup results, which have the same type as W.");
6771
AddAttr<bool>("is_sparse",
6872
"(boolean, default false) "

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,28 +118,31 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
118118
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
119119

120120
auto *ids_data = ids->data<int64_t>();
121-
auto ids_dim = ids->dims();
121+
int64_t ids_num = ids->numel();
122122

123123
auto stream = dev_ctx.stream();
124124
// copy GPU memory to CPU pinned memory
125125
framework::Vector<int64_t> new_rows;
126-
new_rows.resize(ids_dim[0]);
126+
new_rows.resize(ids_num);
127127
auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
128128

129129
// TODO(yuyang18): Strange code here.
130130
memory::Copy(platform::CPUPlace(),
131131
new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
132-
ids_data, ids_dim[0] * sizeof(int64_t), stream);
132+
ids_data, ids_num * sizeof(int64_t), stream);
133133

134134
d_table->set_rows(new_rows);
135135

136136
auto *d_table_value = d_table->mutable_value();
137-
d_table_value->Resize({ids_dim[0], table->dims()[1]});
137+
d_table_value->Resize({ids_num, table->dims()[1]});
138138
d_table_value->mutable_data<T>(context.GetPlace());
139139

140140
auto *d_table_data = d_table_value->data<T>();
141141
auto *d_output_data = d_output->data<T>();
142-
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
142+
auto d_output_dims = d_output->dims();
143+
PADDLE_ENFORCE_EQ(
144+
d_table_value->dims(),
145+
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
143146
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
144147
d_output->numel() * sizeof(T), stream);
145148

paddle/fluid/operators/lookup_table_op.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,36 +109,38 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
109109
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
110110

111111
auto *ids_data = ids->data<int64_t>();
112-
auto ids_dim = ids->dims();
112+
int64_t ids_num = ids->numel();
113113

114114
framework::Vector<int64_t> new_rows;
115-
new_rows.reserve(ids_dim[0]);
116-
for (int64_t i = 0; i < ids_dim[0]; i++) {
115+
new_rows.reserve(ids_num);
116+
for (int64_t i = 0; i < ids_num; i++) {
117117
new_rows.push_back(ids_data[i]);
118118
}
119119
d_table->set_rows(new_rows);
120120

121121
auto *d_table_value = d_table->mutable_value();
122-
d_table_value->Resize({ids_dim[0], table_dim[1]});
122+
d_table_value->Resize({ids_num, table_dim[1]});
123123
d_table_value->mutable_data<T>(context.GetPlace());
124124

125125
d_table->set_height(table_dim[0]);
126126

127127
auto *d_output_data = d_output->data<T>();
128128
auto *d_table_data = d_table_value->data<T>();
129129

130-
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
130+
auto d_output_dims = d_output->dims();
131+
PADDLE_ENFORCE_EQ(
132+
d_table_value->dims(),
133+
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
131134
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
132135
} else {
133136
auto *ids = context.Input<LoDTensor>("Ids");
134137
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
135138
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
136139

137140
auto *ids_data = ids->data<int64_t>();
138-
auto ids_dim = ids->dims();
139141

140142
int N = table_dim[0];
141-
int D = d_output->dims()[1];
143+
int D = table_dim[1];
142144

143145
auto *d_output_data = d_output->data<T>();
144146
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());

paddle/fluid/operators/softmax_cudnn_op.cu.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,16 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
3030
// allocate memory on device.
3131
Out->mutable_data<T>(context.GetPlace());
3232

33+
auto dims = X->dims();
34+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
35+
framework::LoDTensor flattened_x;
36+
framework::LoDTensor flattened_out;
37+
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
38+
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
39+
3340
math::SoftmaxCUDNNFunctor<T>()(
34-
context.template device_context<platform::CUDADeviceContext>(), X, Out);
41+
context.template device_context<platform::CUDADeviceContext>(),
42+
&flattened_x, &flattened_out);
3543
}
3644
};
3745

@@ -46,9 +54,18 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
4654
// allocate memory on device.
4755
dX->mutable_data<T>(context.GetPlace());
4856

57+
auto dims = Out->dims();
58+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
59+
framework::LoDTensor flattened_out;
60+
framework::LoDTensor flattened_d_out;
61+
framework::LoDTensor flattened_d_x;
62+
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
63+
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
64+
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
65+
4966
math::SoftmaxGradCUDNNFunctor<T>()(
50-
context.template device_context<platform::CUDADeviceContext>(), Out,
51-
dOut, dX);
67+
context.template device_context<platform::CUDADeviceContext>(),
68+
&flattened_out, &flattened_d_out, &flattened_d_x);
5269
}
5370
};
5471

paddle/fluid/operators/softmax_mkldnn_op.cc

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ using paddle::platform::MKLDNNMemDesc;
2626

2727
using mkldnn::memory; // Note: paddle has also "memory" namespace
2828
using mkldnn::primitive;
29-
using mkldnn::softmax_forward;
30-
using mkldnn::softmax_backward;
3129
using mkldnn::prop_kind;
30+
using mkldnn::softmax_backward;
31+
using mkldnn::softmax_forward;
3232
using mkldnn::stream;
3333
using platform::to_void_cast;
3434

@@ -113,17 +113,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
113113
auto mkldnn_engine = dev_ctx.GetEngine();
114114
const Tensor* input = ctx.Input<Tensor>("X");
115115
Tensor* output = ctx.Output<Tensor>("Out");
116-
PADDLE_ENFORCE(input->dims().size() == 2UL,
117-
"The input of softmax op must be a 2D matrix.");
118-
const T* input_data = input->data<T>();
119-
// allocate memory for output
120-
T* output_data = output->mutable_data<T>(ctx.GetPlace());
121-
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
122-
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
123-
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
124-
// we will make normalization after final eg. axis: 1
125-
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
126-
"Softmax input and output dimensions should match");
116+
PADDLE_ENFORCE_EQ(
117+
input->dims(), output->dims(),
118+
"The shape of softmax's input and output must be identical.");
119+
120+
// make sure 'output' holds memory, which will be shared by
121+
// 'flattened_output' later.
122+
output->mutable_data<T>(ctx.GetPlace());
123+
124+
// flatten input and output to 2-D matrixs
125+
auto dims = input->dims(); // input and output share the same shape
126+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
127+
framework::Tensor flattened_input;
128+
framework::Tensor flattened_output;
129+
flattened_input.ShareDataWith(*input).Resize(flattened_dims);
130+
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
131+
132+
const T* input_data = flattened_input.data<T>();
133+
T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());
134+
135+
std::vector<int> src_tz = paddle::framework::vectorize2int(flattened_dims);
136+
std::vector<int> dst_tz = src_tz;
127137
// Same memory descriptor to be used for input and output
128138
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
129139
// Generate keys for storing/retriving primitives for this operator
@@ -174,23 +184,34 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
174184
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
175185
auto mkldnn_engine = dev_ctx.GetEngine();
176186
const Tensor* output = ctx.Input<Tensor>("Out");
177-
const T* dst_data = output->data<T>();
178-
179187
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
180-
const auto* diff_dst_ptr = dout->template data<T>();
181-
182188
auto* dx =
183189
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
184-
T* diff_src_ptr = dx->template mutable_data<T>(ctx.GetPlace());
185190

186-
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
191+
PADDLE_ENFORCE_EQ(
192+
dout->dims(), dx->dims(),
193+
"The shape of softmax_grad's input and output must be identical.");
194+
195+
// make sure 'dx' holds memory, which will be shared by 'flattened_dx'
196+
// later.
197+
dx->template mutable_data<T>(ctx.GetPlace());
198+
199+
auto dims = dout->dims(); // input and output share the same shape
200+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
201+
framework::Tensor flattened_output;
202+
framework::Tensor flattened_dout;
203+
framework::Tensor flattened_dx;
204+
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
205+
flattened_dout.ShareDataWith(*dout).Resize(flattened_dims);
206+
flattened_dx.ShareDataWith(*dx).Resize(flattened_dims);
207+
208+
const T* dst_data = flattened_output.data<T>();
209+
const T* diff_dst_ptr = flattened_dout.template data<T>();
210+
T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
211+
212+
std::vector<int> dst_tz = paddle::framework::vectorize2int(flattened_dims);
187213
std::vector<int> src_tz(dst_tz);
188-
PADDLE_ENFORCE(output->dims().size() == 2UL,
189-
"The input of softmax op must be a 2D matrix.");
190-
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
191-
// we will make normalization after final eg. axis: 1
192-
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
193-
"Softmax input and output dimensions should match");
214+
194215
// Same memory descriptor to be used for input and output
195216
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
196217
// Currently only supports NC data format

paddle/fluid/operators/softmax_op.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
3737
PADDLE_ENFORCE(ctx->HasOutput("Out"),
3838
"Output(Out) of SoftmaxOp should not be null.");
3939

40-
auto x_dims = ctx->GetInputDim("X");
41-
PADDLE_ENFORCE(x_dims.size() == 2UL,
42-
"The input of softmax op must be a matrix.");
43-
ctx->SetOutputDim("Out", x_dims);
40+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
4441
ctx->ShareLoD("X", /*->*/ "Out");
4542
}
4643

@@ -81,8 +78,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
8178
public:
8279
void Make() override {
8380
AddInput("X",
84-
"The input tensor of softmax. "
85-
"2-D with shape [batch_size, input_feature_dimensions].");
81+
"The input tensor of softmax, "
82+
"whose last dimension is the input_feature_dimensions.");
8683
AddOutput("Out", "The normalized values with the same shape as X.")
8784
.Reuse("X");
8885
AddAttr<bool>(
@@ -105,20 +102,23 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
105102
AddComment(R"DOC(
106103
Softmax Operator.
107104
108-
The input of the softmax operator is a 2-D tensor with shape N x K (N is the
109-
batch_size, K is the dimension of input feature). The output tensor has the
110-
same shape as the input tensor.
105+
The input of the softmax operator is a tensor of any rank. The output tensor
106+
has the same shape as the input.
111107
112-
For each row of the input tensor, the softmax operator squashes the
113-
K-dimensional vector of arbitrary real values to a K-dimensional vector of real
114-
values in the range [0, 1] that add up to 1.
108+
The input tensor will first be logically flattened to a 2-D matrix. The matrix's
109+
second dimension(row length) is as same as the last dimension of the input
110+
tensor, and the first dimension(column length) is the product of all other
111+
dimensions of the input tensor. For each row of the matrix, the softmax operator
112+
squashes the K-dimensional(K is the width of the matrix, which is also the size
113+
of the input tensor's last dimension) vector of arbitrary real values to a
114+
K-dimensional vector of real values in the range [0, 1] that add up to 1.
115115
It computes the exponential of the given dimension and the sum of exponential
116116
values of all the other dimensions in the K-dimensional vector input.
117117
Then the ratio of the exponential of the given dimension and the sum of
118118
exponential values of all the other dimensions is the output of the softmax
119119
operator.
120120
121-
For each row $i$ and each column $j$ in Input(X), we have:
121+
For each row $i$ and each column $j$ in the matrix, we have:
122122
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
123123
124124
)DOC");

paddle/fluid/operators/softmax_op.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,16 @@ class SoftmaxKernel : public framework::OpKernel<T> {
3131
// allocate memory on device.
3232
Out->mutable_data<T>(context.GetPlace());
3333

34+
auto dims = X->dims();
35+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
36+
framework::LoDTensor flattened_x;
37+
framework::LoDTensor flattened_out;
38+
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
39+
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
40+
3441
math::SoftmaxFunctor<DeviceContext, T>()(
35-
context.template device_context<DeviceContext>(), X, Out);
42+
context.template device_context<DeviceContext>(), &flattened_x,
43+
&flattened_out);
3644
}
3745
};
3846

@@ -47,8 +55,18 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
4755
// allocate memory on device.
4856
dX->mutable_data<T>(context.GetPlace());
4957

58+
auto dims = Out->dims();
59+
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
60+
framework::LoDTensor flattened_out;
61+
framework::LoDTensor flattened_d_out;
62+
framework::LoDTensor flattened_d_x;
63+
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
64+
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
65+
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
66+
5067
math::SoftmaxGradFunctor<DeviceContext, T>()(
51-
context.template device_context<DeviceContext>(), Out, dOut, dX);
68+
context.template device_context<DeviceContext>(), &flattened_out,
69+
&flattened_d_out, &flattened_d_x);
5270
}
5371
};
5472

python/paddle/fluid/layers/nn.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,21 +1313,24 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
13131313

13141314
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None):
13151315
"""
1316-
The input of the softmax layer is a 2-D tensor with shape N x K (N is the
1317-
batch_size, K is the dimension of input feature). The output tensor has the
1318-
same shape as the input tensor.
1316+
The input of the softmax operator is a tensor of any rank. The output tensor
1317+
has the same shape as the input.
13191318
1320-
For each row of the input tensor, the softmax operator squashes the
1321-
K-dimensional vector of arbitrary real values to a K-dimensional vector of real
1322-
values in the range [0, 1] that add up to 1.
1319+
The input tensor will first be logically flattened to a 2-D matrix. The matrix's
1320+
second dimension(row length) is as same as the last dimension of the input
1321+
tensor, and the first dimension(column length) is the product of all other
1322+
dimensions of the input tensor. For each row of the matrix, the softmax operator
1323+
squashes the K-dimensional(K is the width of the matrix, which is also the size
1324+
of the input tensor's last dimension) vector of arbitrary real values to a
1325+
K-dimensional vector of real values in the range [0, 1] that add up to 1.
13231326
13241327
It computes the exponential of the given dimension and the sum of exponential
13251328
values of all the other dimensions in the K-dimensional vector input.
13261329
Then the ratio of the exponential of the given dimension and the sum of
13271330
exponential values of all the other dimensions is the output of the softmax
13281331
operator.
13291332
1330-
For each row :math:`i` and each column :math:`j` in Input(X), we have:
1333+
For each row :math:`i` and each column :math:`j` in the matrix, we have:
13311334
13321335
.. math::
13331336

0 commit comments

Comments
 (0)