@@ -43,25 +43,25 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
43
43
44
44
auto batch_size = x->dims ()[0 ];
45
45
auto weight_dims = weight->dims ();
46
- int Out_dim = weight_dims[0 ];
47
- int X_dim = weight_dims[1 ];
48
- int Y_dim = weight_dims[2 ];
46
+ int out_dim = weight_dims[0 ];
47
+ auto x_dim = weight_dims[1 ];
48
+ auto y_dim = weight_dims[2 ];
49
49
auto place = ctx.GetEigenDevice <Place>();
50
50
51
51
// Create the intermediate variable to calculate the result of
52
52
// Input(X) multiplied by Input(Weight_i), the formula is:
53
53
// left_mul = X Weight_i.
54
54
Tensor left_mul;
55
- left_mul.mutable_data <T>(framework::make_ddim ({batch_size, Y_dim }),
55
+ left_mul.mutable_data <T>(framework::make_ddim ({batch_size, y_dim }),
56
56
ctx.GetPlace ());
57
57
auto left_mul_mat = EigenMatrix<T>::From (left_mul);
58
58
59
- for (int i = 0 ; i < Out_dim ; ++i) {
59
+ for (int i = 0 ; i < out_dim ; ++i) {
60
60
auto output_col_vec = output_mat.chip (i, 1 );
61
61
Tensor weight_mat =
62
- weight->Slice (i, i + 1 ).Resize (framework::make_ddim ({X_dim, Y_dim }));
62
+ weight->Slice (i, i + 1 ).Resize (framework::make_ddim ({x_dim, y_dim }));
63
63
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasNoTrans,
64
- batch_size, Y_dim, X_dim , 1 , x->data <T>(),
64
+ batch_size, y_dim, x_dim , 1 , x->data <T>(),
65
65
weight_mat.data <T>(), 0 , left_mul.data <T>());
66
66
output_col_vec.device (place) =
67
67
(left_mul_mat * y_mat).sum (Eigen::DSizes<int , 1 >(1 ));
@@ -89,9 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
89
89
90
90
auto batch_size = x->dims ()[0 ];
91
91
auto weight_dims = weight->dims ();
92
- int Out_dim = weight_dims[0 ];
93
- int X_dim = weight_dims[1 ];
94
- int Y_dim = weight_dims[2 ];
92
+ int out_dim = weight_dims[0 ];
93
+ auto x_dim = weight_dims[1 ];
94
+ auto y_dim = weight_dims[2 ];
95
95
96
96
auto x_mat = EigenMatrix<T>::From (*x);
97
97
auto y_mat = EigenMatrix<T>::From (*y);
@@ -100,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
100
100
101
101
// Create the intermediate variable to calculate the Output(Y@Grad).
102
102
Tensor x_scale;
103
- x_scale.mutable_data <T>(framework::make_ddim ({batch_size, X_dim }),
103
+ x_scale.mutable_data <T>(framework::make_ddim ({batch_size, x_dim }),
104
104
ctx.GetPlace ());
105
105
auto x_scale_mat = EigenMatrix<T>::From (x_scale);
106
106
107
107
// Create the intermediate variable to calculate the Output(X@Grad).
108
108
Tensor y_scale;
109
- y_scale.mutable_data <T>(framework::make_ddim ({batch_size, Y_dim }),
109
+ y_scale.mutable_data <T>(framework::make_ddim ({batch_size, y_dim }),
110
110
ctx.GetPlace ());
111
111
auto y_scale_mat = EigenMatrix<T>::From (y_scale);
112
112
@@ -126,19 +126,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
126
126
127
127
// Calculate the Output(X@Grad) and Output(Y@Grad).
128
128
if (d_x || d_y) {
129
- Eigen::DSizes<int , 2 > bcast_for_x (1 , Y_dim );
130
- Eigen::DSizes<int , 2 > bcast_for_y (1 , X_dim );
131
- for (int i = 0 ; i < Out_dim ; ++i) {
129
+ Eigen::DSizes<int , 2 > bcast_for_x (1 , y_dim );
130
+ Eigen::DSizes<int , 2 > bcast_for_y (1 , x_dim );
131
+ for (int i = 0 ; i < out_dim ; ++i) {
132
132
Tensor weight_i = weight->Slice (i, i + 1 ).Resize (
133
- framework::make_ddim ({X_dim, Y_dim }));
133
+ framework::make_ddim ({x_dim, y_dim }));
134
134
auto output_vec = d_out_mat.chip (i, 1 );
135
135
if (d_x) {
136
136
y_scale_mat.device (place) =
137
137
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
138
138
.broadcast (bcast_for_x) *
139
139
y_mat;
140
140
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasTrans,
141
- batch_size, X_dim, Y_dim , 1 , y_scale.data <T>(),
141
+ batch_size, x_dim, y_dim , 1 , y_scale.data <T>(),
142
142
weight_i.data <T>(), 1 , d_x->data <T>());
143
143
}
144
144
if (d_y) {
@@ -147,7 +147,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
147
147
.broadcast (bcast_for_y) *
148
148
x_mat;
149
149
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasNoTrans,
150
- batch_size, Y_dim, X_dim , 1 , x_scale.data <T>(),
150
+ batch_size, y_dim, x_dim , 1 , x_scale.data <T>(),
151
151
weight_i.data <T>(), 1 , d_y->data <T>());
152
152
}
153
153
}
@@ -156,17 +156,17 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
156
156
// Calculate the gradient of Input(Weight).
157
157
if (d_weight) {
158
158
d_weight->mutable_data <T>(ctx.GetPlace ());
159
- Eigen::DSizes<int , 2 > bcast_for_weight (1 , X_dim );
160
- for (int i = 0 ; i < Out_dim ; ++i) {
159
+ Eigen::DSizes<int , 2 > bcast_for_weight (1 , x_dim );
160
+ for (int i = 0 ; i < out_dim ; ++i) {
161
161
Tensor d_weight_i = d_weight->Slice (i, i + 1 ).Resize (
162
- framework::make_ddim ({X_dim, Y_dim }));
162
+ framework::make_ddim ({x_dim, y_dim }));
163
163
auto output_vec = d_out_mat.chip (i, 1 );
164
164
x_scale_mat.device (place) =
165
165
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
166
166
.broadcast (bcast_for_weight) *
167
167
x_mat;
168
168
math::gemm<Place, T>(ctx.device_context (), CblasTrans, CblasNoTrans,
169
- X_dim, Y_dim , batch_size, 1 , x_scale.data <T>(),
169
+ x_dim, y_dim , batch_size, 1 , x_scale.data <T>(),
170
170
y->data <T>(), 0 , d_weight_i.data <T>());
171
171
}
172
172
}
0 commit comments