Skip to content

Commit e9695f4

Browse files
authored
Merge pull request #5014 from peterzhang2029/bi_tensor_prod_op
Add Bilinear Tensor Product operator.
2 parents 05c0908 + c5d7107 commit e9695f4

File tree

4 files changed

+406
-0
lines changed

4 files changed

+406
-0
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/bilinear_tensor_product_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using framework::Tensor;
21+
22+
class BilinearTensorProductOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
protected:
27+
void InferShape(framework::InferShapeContext* ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
29+
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
30+
PADDLE_ENFORCE(ctx->HasInput("Weight"),
31+
"Input(Weight) should not be null.");
32+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
33+
auto x_dims = ctx->GetInputDim("X");
34+
auto y_dims = ctx->GetInputDim("Y");
35+
auto weight_dims = ctx->GetInputDim("Weight");
36+
37+
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor.");
38+
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor.");
39+
PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL,
40+
"The input(Weight) must be a 3D tensor.");
41+
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
42+
"The first dimension(batch_size) of input(X) must be "
43+
"equal to the first dimension of the input(Y).");
44+
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
45+
"The second dimension of input(X) must be equal to "
46+
"the second dimension of the input(Weight).");
47+
PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
48+
"The second dimension of input(Y) must be equal to "
49+
"the third dimension of the input(Weight).");
50+
51+
if (ctx->HasInput("Bias")) {
52+
auto bias_dims = ctx->GetInputDim("Bias");
53+
PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL,
54+
"The Input(Bias) must be a 2-D tensor with "
55+
"the 2nd dimension fixed to 1 (a row vector).");
56+
PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
57+
"The second dimension of input(Bias) must be equal "
58+
"to the first dimension of the input(Weight).");
59+
}
60+
61+
ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]});
62+
ctx->ShareLoD("X", /*->*/ "Out");
63+
}
64+
};
65+
66+
/**
 * Proto/checker maker for bilinear_tensor_product.
 *
 * Declares the inputs X, Y, Weight and the dispensable Bias, the output Out,
 * and the operator's user-facing documentation string.
 */
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BilinearTensorProductOpMaker(framework::OpProto* proto,
                               framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The first input of bilinear_tensor_product operator.");
    AddInput("Y", "The second input of bilinear_tensor_product operator.");
    AddInput("Weight",
             "The learnable parameters of bilinear_tensor_product operator.");
    AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.")
        .AsDispensable();
    AddOutput("Out", "The output of bilinear_tensor_product operator.");
    // The summation index is j: i already names the Weight slice / output
    // column, so reusing it as the bound index made the formula ambiguous.
    AddComment(R"DOC(
Bilinear Tensor Product operator.
Given input X and Y, a 3D tensor weight, and bias. Each column of the
output is computed by one slice i = 1, . . . , k of the tensor:

    M = (X W_i) \cdot Y
    Out_i = \sum_j {M_j} + Bias_i

)DOC");
  }
};
89+
90+
class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
91+
public:
92+
using framework::OperatorWithKernel::OperatorWithKernel;
93+
94+
protected:
95+
void InferShape(framework::InferShapeContext* ctx) const override {
96+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
97+
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
98+
PADDLE_ENFORCE(ctx->HasInput("Weight"),
99+
"Input(Weight) should not be null.");
100+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
101+
"Input(Out@GRAD) should not be null.");
102+
auto x_dims = ctx->GetInputDim("X");
103+
auto y_dims = ctx->GetInputDim("Y");
104+
auto weight_dims = ctx->GetInputDim("Weight");
105+
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
106+
107+
PADDLE_ENFORCE_EQ(out_dims.size(), 2UL,
108+
"The input(Out@GRAD) must be a 2D Tensor.");
109+
PADDLE_ENFORCE_EQ(
110+
x_dims[0], out_dims[0],
111+
"The first dimension(batch_size) of input(Out@GRAD) must be "
112+
"equal to the first dimension of the Input(X).");
113+
PADDLE_ENFORCE_EQ(
114+
weight_dims[0], out_dims[1],
115+
"The second dimension of input(Out@GRAD) must be equal to "
116+
"the third dimension of the Input(Weight).");
117+
118+
if (ctx->HasInput("Bias")) {
119+
auto bias_dims = ctx->GetInputDim("Bias");
120+
PADDLE_ENFORCE_EQ(
121+
bias_dims[1], out_dims[1],
122+
"The second dimension of input(Out@GRAD) must be equal to "
123+
"the second dimension of the Input(Bias).");
124+
auto bias_grad_name = framework::GradVarName("Bias");
125+
if (ctx->HasOutput(bias_grad_name))
126+
ctx->SetOutputDim(bias_grad_name, bias_dims);
127+
}
128+
129+
auto x_grad_name = framework::GradVarName("X");
130+
auto y_grad_name = framework::GradVarName("Y");
131+
auto weight_grad_name = framework::GradVarName("Weight");
132+
133+
if (ctx->HasOutput(x_grad_name)) {
134+
ctx->SetOutputDim(x_grad_name, x_dims);
135+
}
136+
if (ctx->HasOutput(y_grad_name)) {
137+
ctx->SetOutputDim(y_grad_name, y_dims);
138+
}
139+
if (ctx->HasOutput(weight_grad_name)) {
140+
ctx->SetOutputDim(weight_grad_name, weight_dims);
141+
}
142+
}
143+
};
144+
145+
} // namespace operators
146+
} // namespace paddle
147+
148+
// Shorthand for the operator namespace used by the registration macros below.
namespace ops = paddle::operators;
149+
// Ties the forward op class and its maker to the name
// bilinear_tensor_product, and the gradient op class to
// bilinear_tensor_product_grad.
REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
150+
ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad,
151+
ops::BilinearTensorProductOpGrad);
152+
// Forward kernel on CPUPlace, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
153+
bilinear_tensor_product,
154+
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>,
155+
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double>);
156+
// Gradient kernel on CPUPlace, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
157+
bilinear_tensor_product_grad,
158+
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>,
159+
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double>);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#define EIGEN_USE_GPU
16+
#include "paddle/operators/bilinear_tensor_product_op.h"
17+
18+
// Shorthand for the operator namespace used by the registration macros below.
namespace ops = paddle::operators;
19+
// Forward kernel on GPUPlace, instantiated for float and double.
REGISTER_OP_GPU_KERNEL(
20+
bilinear_tensor_product,
21+
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>,
22+
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, double>);
23+
// Gradient kernel on GPUPlace, instantiated for float and double.
REGISTER_OP_GPU_KERNEL(
24+
bilinear_tensor_product_grad,
25+
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>,
26+
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, double>);
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/framework/eigen.h"
18+
#include "paddle/framework/op_registry.h"
19+
#include "paddle/operators/math/math_function.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using framework::Tensor;
25+
26+
// File-local shorthand for framework::EigenMatrix, defaulting to row-major
// storage and Eigen::DenseIndex as the index type.
template <typename T, int MajorType = Eigen::RowMajor,
27+
typename IndexType = Eigen::DenseIndex>
28+
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
29+
30+
template <typename Place, typename T>
31+
/**
 * Forward kernel of bilinear_tensor_product.
 *
 * For every slice i of Weight (taken along its first dimension), computes
 *   Out[:, i] = row-wise sum of ((X * Weight_i) .* Y)
 * where `*` is a matrix product done with gemm and `.*` is element-wise.
 * If the optional Bias input is present it is broadcast over the batch
 * dimension and added to Out.
 */
class BilinearTensorProductKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* x_in = ctx.Input<Tensor>("X");
    const Tensor* y_in = ctx.Input<Tensor>("Y");
    const Tensor* w_in = ctx.Input<Tensor>("Weight");
    const Tensor* b_in = ctx.Input<Tensor>("Bias");
    Tensor* result = ctx.Output<Tensor>("Out");
    // Allocate Out before an Eigen view is taken over its buffer.
    result->mutable_data<T>(ctx.GetPlace());

    auto y_view = EigenMatrix<T>::From(*y_in);
    auto result_view = EigenMatrix<T>::From(*result);

    auto rows = x_in->dims()[0];
    auto w_shape = w_in->dims();
    int num_slices = w_shape[0];
    auto x_width = w_shape[1];
    auto y_width = w_shape[2];
    auto device = ctx.GetEigenDevice<Place>();

    // Scratch buffer holding X * Weight_i for the slice being processed;
    // it is overwritten on every iteration (gemm beta == 0).
    Tensor xw;
    xw.mutable_data<T>(framework::make_ddim({rows, y_width}), ctx.GetPlace());
    auto xw_view = EigenMatrix<T>::From(xw);

    // Reduce along dimension 1 (the columns of each row).
    const Eigen::DSizes<int, 1> reduce_over_columns(1);
    for (int slice = 0; slice < num_slices; ++slice) {
      // View slice `slice` of Weight as an [x_width, y_width] matrix.
      Tensor w_slice = w_in->Slice(slice, slice + 1)
                           .Resize(framework::make_ddim({x_width, y_width}));
      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
                           rows, y_width, x_width, 1, x_in->data<T>(),
                           w_slice.data<T>(), 0, xw.data<T>());
      // Column `slice` of Out receives the per-row sums.
      result_view.chip(slice, 1).device(device) =
          (xw_view * y_view).sum(reduce_over_columns);
    }

    if (b_in != nullptr) {
      auto bias_view = EigenMatrix<T>::From(*b_in);
      // Repeat the bias along the batch dimension so it matches Out.
      Eigen::DSizes<int, 2> repeat_rows(rows, 1);
      result_view.device(device) =
          bias_view.broadcast(repeat_rows) + result_view;
    }
  }
};
76+
77+
template <typename Place, typename T>
78+
/**
 * Gradient kernel of bilinear_tensor_product.
 *
 * Reads X, Y, Weight and Out@GRAD and fills whichever of X@GRAD, Y@GRAD,
 * Weight@GRAD and Bias@GRAD are non-null. With g_i = Out@GRAD[:, i] and
 * "g_i (.) M" meaning every row of M scaled by the matching entry of g_i:
 *   X@GRAD        = sum_i (g_i (.) Y) * Weight_i^T
 *   Y@GRAD        = sum_i (g_i (.) X) * Weight_i
 *   Weight@GRAD_i = (g_i (.) X)^T * Y
 *   Bias@GRAD     = sum of Out@GRAD over the batch dimension
 */
class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* y = ctx.Input<Tensor>("Y");
    const Tensor* weight = ctx.Input<Tensor>("Weight");
    Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    Tensor* d_y = ctx.Output<Tensor>(framework::GradVarName("Y"));
    Tensor* d_weight = ctx.Output<Tensor>(framework::GradVarName("Weight"));
    Tensor* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
    const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto batch_size = x->dims()[0];
    auto weight_dims = weight->dims();
    int out_dim = weight_dims[0];
    auto x_dim = weight_dims[1];
    auto y_dim = weight_dims[2];

    auto x_mat = EigenMatrix<T>::From(*x);
    auto y_mat = EigenMatrix<T>::From(*y);
    auto d_out_mat = EigenMatrix<T>::From(*d_out);
    auto place = ctx.GetEigenDevice<Place>();

    // Create the intermediate variable to calculate the Output(Y@Grad).
    Tensor x_scale;
    x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
                            ctx.GetPlace());
    auto x_scale_mat = EigenMatrix<T>::From(x_scale);

    // Create the intermediate variable to calculate the Output(X@Grad).
    Tensor y_scale;
    y_scale.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
                            ctx.GetPlace());
    auto y_scale_mat = EigenMatrix<T>::From(y_scale);

    math::SetConstant<Place, T> set_zero;

    // Set Output(X@Grad) to zero; the gemm calls below accumulate into it
    // (beta == 1).
    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
      set_zero(ctx.device_context(), d_x, static_cast<T>(0));
    }

    // Set Output(Y@Grad) to zero; the gemm calls below accumulate into it
    // (beta == 1).
    if (d_y) {
      d_y->mutable_data<T>(ctx.GetPlace());
      set_zero(ctx.device_context(), d_y, static_cast<T>(0));
    }

    // Calculate the Output(X@Grad) and Output(Y@Grad).
    if (d_x || d_y) {
      Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
      Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
      for (int i = 0; i < out_dim; ++i) {
        // View slice i of Weight as an [x_dim, y_dim] matrix.
        Tensor weight_i = weight->Slice(i, i + 1).Resize(
            framework::make_ddim({x_dim, y_dim}));
        auto output_vec = d_out_mat.chip(i, 1);
        if (d_x) {
          // y_scale = g_i (.) Y, then d_x += y_scale * Weight_i^T.
          y_scale_mat.device(place) =
              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                  .broadcast(bcast_for_x) *
              y_mat;
          math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
                               batch_size, x_dim, y_dim, 1, y_scale.data<T>(),
                               weight_i.data<T>(), 1, d_x->data<T>());
        }
        if (d_y) {
          // x_scale = g_i (.) X, then d_y += x_scale * Weight_i.
          x_scale_mat.device(place) =
              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                  .broadcast(bcast_for_y) *
              x_mat;
          math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
                               batch_size, y_dim, x_dim, 1, x_scale.data<T>(),
                               weight_i.data<T>(), 1, d_y->data<T>());
        }
      }
    }

    // Calculate the gradient of Input(Weight).
    if (d_weight) {
      d_weight->mutable_data<T>(ctx.GetPlace());
      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
      for (int i = 0; i < out_dim; ++i) {
        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
            framework::make_ddim({x_dim, y_dim}));
        auto output_vec = d_out_mat.chip(i, 1);
        // x_scale = g_i (.) X, then d_weight_i = x_scale^T * Y. beta == 0, so
        // each slice is overwritten and needs no prior zeroing.
        x_scale_mat.device(place) =
            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                .broadcast(bcast_for_weight) *
            x_mat;
        math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
                             x_dim, y_dim, batch_size, 1, x_scale.data<T>(),
                             y->data<T>(), 0, d_weight_i.data<T>());
      }
    }

    // Calculate the gradient of Input(Bias).
    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      // Reducing the rank-2 Out@GRAD over dimension 0 yields a rank-1
      // expression, so the destination is viewed as a flat rank-1 vector.
      // Assigning it to a rank-2 matrix view (as before) mismatched the
      // ranks of the two sides of the Eigen assignment.
      auto d_bias_vec = framework::EigenVector<T>::Flatten(*d_bias);
      d_bias_vec.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
    }
  }
};
182+
183+
} // namespace operators
184+
} // namespace paddle

0 commit comments

Comments
 (0)