
Commit 36e26a5

jacquesqiao authored and qingqing01 committed
Optimize bilinear tensor product op (#14485)
* optimize bilinear_tensor_product
* add set zero to set grad to 0.
1 parent 4ec9de0 commit 36e26a5
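The recurring change in the diff below appends .eval() to Eigen broadcast expressions, so the broadcast is materialized into a temporary once instead of staying a lazy expression whose coefficients (and index arithmetic) are recomputed on every read. A minimal standalone sketch of the same pattern as the forward kernel's bias add (plain Eigen outside Paddle; sizes and values are made up for illustration):

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  const int batch_size = 4;
  const int out_dim = 3;
  Eigen::Tensor<float, 2> bias(1, out_dim);             // stands in for Input(Bias)
  Eigen::Tensor<float, 2> output(batch_size, out_dim);  // stands in for Output(Out)
  bias.setConstant(0.5f);
  output.setConstant(1.0f);

  // Replicate the 1 x out_dim bias across the batch dimension.
  Eigen::DSizes<int, 2> bcast(batch_size, 1);

  // .eval() forces the broadcast into a concrete temporary before the add;
  // without it, the lazy broadcast expression is re-evaluated per coefficient.
  output = bias.broadcast(bcast).eval() + output;

  std::cout << output << std::endl;  // every entry is 1.5
  return 0;
}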

File tree

1 file changed: +30 additions, -31 deletions

paddle/fluid/operators/bilinear_tensor_product_op.h

Lines changed: 30 additions & 31 deletions
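As a reference for reading the diff, the operation the two kernels compute, per batch row b and output channel i (a restatement of the op's definition, not text from the commit):

$$\mathrm{out}_{b,i} = x_b^{\top} W_i \, y_b + \mathrm{bias}_i$$

$$\frac{\partial L}{\partial x_b} = \sum_i g_{b,i}\, W_i\, y_b, \qquad
\frac{\partial L}{\partial y_b} = \sum_i g_{b,i}\, W_i^{\top} x_b, \qquad
\frac{\partial L}{\partial W_i} = \sum_b g_{b,i}\, x_b\, y_b^{\top}, \qquad
g_{b,i} = \frac{\partial L}{\partial\, \mathrm{out}_{b,i}}$$

The shared factor is g_{b,i} x_b: broadcast over x_dim it is exactly the x_scale buffer in the code below, and it feeds both the Y@Grad and the Weight@Grad GEMMs, which is why the commit folds the old standalone weight-gradient loop into the main d_x/d_y loop and computes x_scale once per i.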
@@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     if (bias) {
       auto bias_vec = EigenMatrix<T>::From(*bias);
       Eigen::DSizes<int, 2> bcast(batch_size, 1);
-      output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
+      output_mat.device(place) = bias_vec.broadcast(bcast).eval() + output_mat;
     }
   }
 };
@@ -99,79 +99,78 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     auto d_out_mat = EigenMatrix<T>::From(*d_out);
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    // Create the intermediate variable to caculate the Output(Y@Grad).
+    // Create the intermediate variable to calculate the Output(Y@Grad).
     Tensor x_scale;
     x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
                             ctx.GetPlace());
     auto x_scale_mat = EigenMatrix<T>::From(x_scale);
 
-    // Create the intermediate variable to caculate the Output(X@Grad).
+    // Create the intermediate variable to calculate the Output(X@Grad).
     Tensor y_scale;
     y_scale.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
                             ctx.GetPlace());
     auto y_scale_mat = EigenMatrix<T>::From(y_scale);
 
     math::SetConstant<DeviceContext, T> set_zero;
 
-    // Set Output(X@Grad) be zero.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_x, static_cast<T>(0));
     }
 
-    // Set Output(Y@Grad) be zero.
     if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_y, static_cast<T>(0));
     }
 
+    if (d_weight) {
+      d_weight->mutable_data<T>(ctx.GetPlace());
+    }
+
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
 
     // Caculate the Output(X@Grad) and Output(Y@Grad).
-    if (d_x || d_y) {
+    if (d_x || d_y || d_weight) {
       Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
       Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
+      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
+
       for (int i = 0; i < out_dim; ++i) {
         Tensor weight_i = weight->Slice(i, i + 1).Resize(
             framework::make_ddim({x_dim, y_dim}));
         auto output_vec = d_out_mat.chip(i, 1);
+
         if (d_x) {
           y_scale_mat.device(place) =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_x) *
+                  .broadcast(bcast_for_x)
+                  .eval() *
               y_mat;
           blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
                     y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
         }
-        if (d_y) {
-          x_scale_mat.device(place) =
+
+        if (d_y || d_weight) {
+          auto output_vec_y =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_y) *
-              x_mat;
-          blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
-                    x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+                  .broadcast(bcast_for_y)
+                  .eval();
+          x_scale_mat.device(place) = output_vec_y * x_mat;
+          if (d_y) {
+            blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+                      x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+          }
+          if (d_weight) {
+            Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
+                framework::make_ddim({x_dim, y_dim}));
+            blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
+                      x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
+          }
         }
       }
     }
 
-    // Caculate the gradient of Input(Weight).
-    if (d_weight) {
-      d_weight->mutable_data<T>(ctx.GetPlace());
-      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
-      for (int i = 0; i < out_dim; ++i) {
-        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
-            framework::make_ddim({x_dim, y_dim}));
-        auto output_vec = d_out_mat.chip(i, 1);
-        x_scale_mat.device(place) =
-            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                .broadcast(bcast_for_weight) *
-            x_mat;
-        blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
-                  x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
-      }
-    }
-
-    // Caculate the gradient of Input(Bias).
+    // calculate the gradient of Input(Bias).
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
       auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
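The "add set zero" half of the commit message follows from the GEMM beta convention, C = alpha * op(A) * op(B) + beta * C. The d_x and d_y calls in the loop pass beta = 1 and accumulate contributions across the out_dim iterations, which is only correct if those buffers start at zero (hence the set_zero calls), while the d_weight call passes beta = 0 and overwrites, so d_weight only needs to be allocated. A naive reference GEMM illustrating the convention (a plain C++ sketch, not the Paddle/BLAS code):

#include <cstdio>
#include <vector>

// C = alpha * A * B + beta * C for row-major m x k, k x n, m x n matrices
// (no transpose handling; illustration only).
void naive_gemm(int m, int n, int k, float alpha, const float* A,
                const float* B, float beta, float* C) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}

int main() {
  std::vector<float> A(4, 1.0f), B(4, 1.0f), C(4, 5.0f);  // 2x2; C holds stale 5s
  naive_gemm(2, 2, 2, 1.0f, A.data(), B.data(), 1.0f, C.data());
  std::printf("beta=1: %g\n", C[0]);  // 7 = 2 (A*B) + 5 (stale C): why zeroing matters
  naive_gemm(2, 2, 2, 1.0f, A.data(), B.data(), 0.0f, C.data());
  std::printf("beta=0: %g\n", C[0]);  // 2: old contents ignored
  return 0;
}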
