Skip to content

Commit 0a6262d

Browse files
fix warning
1 parent 665eb01 commit 0a6262d

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

paddle/operators/bilinear_tensor_product_op.h

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,26 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
4343

4444
auto batch_size = x->dims()[0];
4545
auto weight_dims = weight->dims();
46+
int Out_dim = weight_dims[0];
47+
int X_dim = weight_dims[1];
48+
int Y_dim = weight_dims[2];
4649
auto place = ctx.GetEigenDevice<Place>();
4750

4851
// Create the intermediate variable to calculate the result of
4952
// Input(X) multiplied by Input(Weight_i), the formula is:
5053
// left_mul = X Weight_i.
5154
Tensor left_mul;
52-
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
55+
left_mul.mutable_data<T>(framework::make_ddim({batch_size, Y_dim}),
5356
ctx.GetPlace());
5457
auto left_mul_mat = EigenMatrix<T>::From(left_mul);
5558

56-
for (size_t i = 0; i < weight_dims[0]; ++i) {
59+
for (int i = 0; i < Out_dim; ++i) {
5760
auto output_col_vec = output_mat.chip(i, 1);
58-
Tensor weight_mat = weight->Slice(i, i + 1).Resize(
59-
framework::make_ddim({weight_dims[1], weight_dims[2]}));
61+
Tensor weight_mat =
62+
weight->Slice(i, i + 1).Resize(framework::make_ddim({X_dim, Y_dim}));
6063
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
61-
batch_size, weight_dims[2], weight_dims[1], 1,
62-
x->data<T>(), weight_mat.data<T>(), 0,
63-
left_mul.data<T>());
64+
batch_size, Y_dim, X_dim, 1, x->data<T>(),
65+
weight_mat.data<T>(), 0, left_mul.data<T>());
6466
output_col_vec.device(place) =
6567
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
6668
}
@@ -87,6 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
8789

8890
auto batch_size = x->dims()[0];
8991
auto weight_dims = weight->dims();
92+
int Out_dim = weight_dims[0];
93+
int X_dim = weight_dims[1];
94+
int Y_dim = weight_dims[2];
9095

9196
auto x_mat = EigenMatrix<T>::From(*x);
9297
auto y_mat = EigenMatrix<T>::From(*y);
@@ -95,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
95100

96101
// Create the intermediate variable to calculate the Output(Y@Grad).
97102
Tensor x_scale;
98-
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
103+
x_scale.mutable_data<T>(framework::make_ddim({batch_size, X_dim}),
99104
ctx.GetPlace());
100105
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
101106

102107
// Create the intermediate variable to calculate the Output(X@Grad).
103108
Tensor y_scale;
104-
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
109+
y_scale.mutable_data<T>(framework::make_ddim({batch_size, Y_dim}),
105110
ctx.GetPlace());
106111
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
107112

@@ -121,51 +126,48 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
121126

122127
// Calculate the Output(X@Grad) and Output(Y@Grad).
123128
if (d_x || d_y) {
124-
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
125-
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
126-
for (int i = 0; i < weight_dims[0]; ++i) {
129+
Eigen::DSizes<int, 2> bcast_for_x(1, Y_dim);
130+
Eigen::DSizes<int, 2> bcast_for_y(1, X_dim);
131+
for (int i = 0; i < Out_dim; ++i) {
127132
Tensor weight_i = weight->Slice(i, i + 1).Resize(
128-
framework::make_ddim({weight_dims[1], weight_dims[2]}));
133+
framework::make_ddim({X_dim, Y_dim}));
129134
auto output_vec = d_out_mat.chip(i, 1);
130135
if (d_x) {
131136
y_scale_mat.device(place) =
132137
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
133138
.broadcast(bcast_for_x) *
134139
y_mat;
135140
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
136-
batch_size, weight_dims[1], weight_dims[2], 1,
137-
y_scale.data<T>(), weight_i.data<T>(), 1,
138-
d_x->data<T>());
141+
batch_size, X_dim, Y_dim, 1, y_scale.data<T>(),
142+
weight_i.data<T>(), 1, d_x->data<T>());
139143
}
140144
if (d_y) {
141145
x_scale_mat.device(place) =
142146
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
143147
.broadcast(bcast_for_y) *
144148
x_mat;
145149
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
146-
batch_size, weight_dims[2], weight_dims[1], 1,
147-
x_scale.data<T>(), weight_i.data<T>(), 1,
148-
d_y->data<T>());
150+
batch_size, Y_dim, X_dim, 1, x_scale.data<T>(),
151+
weight_i.data<T>(), 1, d_y->data<T>());
149152
}
150153
}
151154
}
152155

153156
// Calculate the gradient of Input(Weight).
154157
if (d_weight) {
155158
d_weight->mutable_data<T>(ctx.GetPlace());
156-
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
157-
for (int i = 0; i < weight_dims[0]; ++i) {
159+
Eigen::DSizes<int, 2> bcast_for_weight(1, X_dim);
160+
for (int i = 0; i < Out_dim; ++i) {
158161
Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
159-
framework::make_ddim({weight_dims[1], weight_dims[2]}));
162+
framework::make_ddim({X_dim, Y_dim}));
160163
auto output_vec = d_out_mat.chip(i, 1);
161164
x_scale_mat.device(place) =
162165
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
163166
.broadcast(bcast_for_weight) *
164167
x_mat;
165168
math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
166-
weight_dims[1], weight_dims[2], batch_size, 1,
167-
x_scale.data<T>(), y->data<T>(), 0,
168-
d_weight_i.data<T>());
169+
X_dim, Y_dim, batch_size, 1, x_scale.data<T>(),
170+
y->data<T>(), 0, d_weight_i.data<T>());
169171
}
170172
}
171173

0 commit comments

Comments
 (0)