@@ -43,24 +43,26 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
43
43
44
44
auto batch_size = x->dims ()[0 ];
45
45
auto weight_dims = weight->dims ();
46
+ int Out_dim = weight_dims[0 ];
47
+ int X_dim = weight_dims[1 ];
48
+ int Y_dim = weight_dims[2 ];
46
49
auto place = ctx.GetEigenDevice <Place>();
47
50
48
51
// Create the intermediate variable to calculate the result of
49
52
// Input(X) multiplied by Input(Weight_i), the formula is:
50
53
// left_mul = X Weight_i.
51
54
Tensor left_mul;
52
- left_mul.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[ 2 ] }),
55
+ left_mul.mutable_data <T>(framework::make_ddim ({batch_size, Y_dim }),
53
56
ctx.GetPlace ());
54
57
auto left_mul_mat = EigenMatrix<T>::From (left_mul);
55
58
56
- for (size_t i = 0 ; i < weight_dims[ 0 ] ; ++i) {
59
+ for (int i = 0 ; i < Out_dim ; ++i) {
57
60
auto output_col_vec = output_mat.chip (i, 1 );
58
- Tensor weight_mat = weight-> Slice (i, i + 1 ). Resize (
59
- framework::make_ddim ({weight_dims[ 1 ], weight_dims[ 2 ] }));
61
+ Tensor weight_mat =
62
+ weight-> Slice (i, i + 1 ). Resize ( framework::make_ddim ({X_dim, Y_dim }));
60
63
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasNoTrans,
61
- batch_size, weight_dims[2 ], weight_dims[1 ], 1 ,
62
- x->data <T>(), weight_mat.data <T>(), 0 ,
63
- left_mul.data <T>());
64
+ batch_size, Y_dim, X_dim, 1 , x->data <T>(),
65
+ weight_mat.data <T>(), 0 , left_mul.data <T>());
64
66
output_col_vec.device (place) =
65
67
(left_mul_mat * y_mat).sum (Eigen::DSizes<int , 1 >(1 ));
66
68
}
@@ -87,6 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
87
89
88
90
auto batch_size = x->dims ()[0 ];
89
91
auto weight_dims = weight->dims ();
92
+ int Out_dim = weight_dims[0 ];
93
+ int X_dim = weight_dims[1 ];
94
+ int Y_dim = weight_dims[2 ];
90
95
91
96
auto x_mat = EigenMatrix<T>::From (*x);
92
97
auto y_mat = EigenMatrix<T>::From (*y);
@@ -95,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
95
100
96
101
// Create the intermediate variable to calculate the Output(Y@Grad).
97
102
Tensor x_scale;
98
- x_scale.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[ 1 ] }),
103
+ x_scale.mutable_data <T>(framework::make_ddim ({batch_size, X_dim }),
99
104
ctx.GetPlace ());
100
105
auto x_scale_mat = EigenMatrix<T>::From (x_scale);
101
106
102
107
// Create the intermediate variable to calculate the Output(X@Grad).
103
108
Tensor y_scale;
104
- y_scale.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[ 2 ] }),
109
+ y_scale.mutable_data <T>(framework::make_ddim ({batch_size, Y_dim }),
105
110
ctx.GetPlace ());
106
111
auto y_scale_mat = EigenMatrix<T>::From (y_scale);
107
112
@@ -121,51 +126,48 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
121
126
122
127
// Calculate the Output(X@Grad) and Output(Y@Grad).
123
128
if (d_x || d_y) {
124
- Eigen::DSizes<int , 2 > bcast_for_x (1 , weight_dims[ 2 ] );
125
- Eigen::DSizes<int , 2 > bcast_for_y (1 , weight_dims[ 1 ] );
126
- for (int i = 0 ; i < weight_dims[ 0 ] ; ++i) {
129
+ Eigen::DSizes<int , 2 > bcast_for_x (1 , Y_dim );
130
+ Eigen::DSizes<int , 2 > bcast_for_y (1 , X_dim );
131
+ for (int i = 0 ; i < Out_dim ; ++i) {
127
132
Tensor weight_i = weight->Slice (i, i + 1 ).Resize (
128
- framework::make_ddim ({weight_dims[ 1 ], weight_dims[ 2 ] }));
133
+ framework::make_ddim ({X_dim, Y_dim }));
129
134
auto output_vec = d_out_mat.chip (i, 1 );
130
135
if (d_x) {
131
136
y_scale_mat.device (place) =
132
137
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
133
138
.broadcast (bcast_for_x) *
134
139
y_mat;
135
140
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasTrans,
136
- batch_size, weight_dims[1 ], weight_dims[2 ], 1 ,
137
- y_scale.data <T>(), weight_i.data <T>(), 1 ,
138
- d_x->data <T>());
141
+ batch_size, X_dim, Y_dim, 1 , y_scale.data <T>(),
142
+ weight_i.data <T>(), 1 , d_x->data <T>());
139
143
}
140
144
if (d_y) {
141
145
x_scale_mat.device (place) =
142
146
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
143
147
.broadcast (bcast_for_y) *
144
148
x_mat;
145
149
math::gemm<Place, T>(ctx.device_context (), CblasNoTrans, CblasNoTrans,
146
- batch_size, weight_dims[2 ], weight_dims[1 ], 1 ,
147
- x_scale.data <T>(), weight_i.data <T>(), 1 ,
148
- d_y->data <T>());
150
+ batch_size, Y_dim, X_dim, 1 , x_scale.data <T>(),
151
+ weight_i.data <T>(), 1 , d_y->data <T>());
149
152
}
150
153
}
151
154
}
152
155
153
156
// Calculate the gradient of Input(Weight).
154
157
if (d_weight) {
155
158
d_weight->mutable_data <T>(ctx.GetPlace ());
156
- Eigen::DSizes<int , 2 > bcast_for_weight (1 , weight_dims[ 1 ] );
157
- for (int i = 0 ; i < weight_dims[ 0 ] ; ++i) {
159
+ Eigen::DSizes<int , 2 > bcast_for_weight (1 , X_dim );
160
+ for (int i = 0 ; i < Out_dim ; ++i) {
158
161
Tensor d_weight_i = d_weight->Slice (i, i + 1 ).Resize (
159
- framework::make_ddim ({weight_dims[ 1 ], weight_dims[ 2 ] }));
162
+ framework::make_ddim ({X_dim, Y_dim }));
160
163
auto output_vec = d_out_mat.chip (i, 1 );
161
164
x_scale_mat.device (place) =
162
165
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
163
166
.broadcast (bcast_for_weight) *
164
167
x_mat;
165
168
math::gemm<Place, T>(ctx.device_context (), CblasTrans, CblasNoTrans,
166
- weight_dims[1 ], weight_dims[2 ], batch_size, 1 ,
167
- x_scale.data <T>(), y->data <T>(), 0 ,
168
- d_weight_i.data <T>());
169
+ X_dim, Y_dim, batch_size, 1 , x_scale.data <T>(),
170
+ y->data <T>(), 0 , d_weight_i.data <T>());
169
171
}
170
172
}
171
173
0 commit comments