@@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
70
70
if (bias) {
71
71
auto bias_vec = EigenMatrix<T>::From (*bias);
72
72
Eigen::DSizes<int , 2 > bcast (batch_size, 1 );
73
- output_mat.device (place) = bias_vec.broadcast (bcast) + output_mat;
73
+ output_mat.device (place) = bias_vec.broadcast (bcast). eval () + output_mat;
74
74
}
75
75
}
76
76
};
@@ -99,79 +99,78 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
99
99
auto d_out_mat = EigenMatrix<T>::From (*d_out);
100
100
auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
101
101
auto & dev_ctx = ctx.template device_context <DeviceContext>();
102
- // Create the intermediate variable to caculate the Output(Y@Grad).
102
+ // Create the intermediate variable to calculate the Output(Y@Grad).
103
103
Tensor x_scale;
104
104
x_scale.mutable_data <T>(framework::make_ddim ({batch_size, x_dim}),
105
105
ctx.GetPlace ());
106
106
auto x_scale_mat = EigenMatrix<T>::From (x_scale);
107
107
108
- // Create the intermediate variable to caculate the Output(X@Grad).
108
+ // Create the intermediate variable to calculate the Output(X@Grad).
109
109
Tensor y_scale;
110
110
y_scale.mutable_data <T>(framework::make_ddim ({batch_size, y_dim}),
111
111
ctx.GetPlace ());
112
112
auto y_scale_mat = EigenMatrix<T>::From (y_scale);
113
113
114
114
math::SetConstant<DeviceContext, T> set_zero;
115
115
116
- // Set Output(X@Grad) be zero.
117
116
if (d_x) {
118
117
d_x->mutable_data <T>(ctx.GetPlace ());
119
118
set_zero (dev_ctx, d_x, static_cast <T>(0 ));
120
119
}
121
120
122
- // Set Output(Y@Grad) be zero.
123
121
if (d_y) {
124
122
d_y->mutable_data <T>(ctx.GetPlace ());
125
123
set_zero (dev_ctx, d_y, static_cast <T>(0 ));
126
124
}
127
125
126
+ if (d_weight) {
127
+ d_weight->mutable_data <T>(ctx.GetPlace ());
128
+ }
129
+
128
130
auto blas = math::GetBlas<DeviceContext, T>(ctx);
129
131
130
132
// Calculate the Output(X@Grad) and Output(Y@Grad).
131
- if (d_x || d_y) {
133
+ if (d_x || d_y || d_weight ) {
132
134
Eigen::DSizes<int , 2 > bcast_for_x (1 , y_dim);
133
135
Eigen::DSizes<int , 2 > bcast_for_y (1 , x_dim);
136
+ Eigen::DSizes<int , 2 > bcast_for_weight (1 , x_dim);
137
+
134
138
for (int i = 0 ; i < out_dim; ++i) {
135
139
Tensor weight_i = weight->Slice (i, i + 1 ).Resize (
136
140
framework::make_ddim ({x_dim, y_dim}));
137
141
auto output_vec = d_out_mat.chip (i, 1 );
142
+
138
143
if (d_x) {
139
144
y_scale_mat.device (place) =
140
145
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
141
- .broadcast (bcast_for_x) *
146
+ .broadcast (bcast_for_x)
147
+ .eval () *
142
148
y_mat;
143
149
blas.GEMM (CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1 ,
144
150
y_scale.data <T>(), weight_i.data <T>(), 1 , d_x->data <T>());
145
151
}
146
- if (d_y) {
147
- x_scale_mat.device (place) =
152
+
153
+ if (d_y || d_weight) {
154
+ auto output_vec_y =
148
155
output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
149
- .broadcast (bcast_for_y) *
150
- x_mat;
151
- blas.GEMM (CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1 ,
152
- x_scale.data <T>(), weight_i.data <T>(), 1 , d_y->data <T>());
156
+ .broadcast (bcast_for_y)
157
+ .eval ();
158
+ x_scale_mat.device (place) = output_vec_y * x_mat;
159
+ if (d_y) {
160
+ blas.GEMM (CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1 ,
161
+ x_scale.data <T>(), weight_i.data <T>(), 1 , d_y->data <T>());
162
+ }
163
+ if (d_weight) {
164
+ Tensor d_weight_i = d_weight->Slice (i, i + 1 ).Resize (
165
+ framework::make_ddim ({x_dim, y_dim}));
166
+ blas.GEMM (CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1 ,
167
+ x_scale.data <T>(), y->data <T>(), 0 , d_weight_i.data <T>());
168
+ }
153
169
}
154
170
}
155
171
}
156
172
157
- // Caculate the gradient of Input(Weight).
158
- if (d_weight) {
159
- d_weight->mutable_data <T>(ctx.GetPlace ());
160
- Eigen::DSizes<int , 2 > bcast_for_weight (1 , x_dim);
161
- for (int i = 0 ; i < out_dim; ++i) {
162
- Tensor d_weight_i = d_weight->Slice (i, i + 1 ).Resize (
163
- framework::make_ddim ({x_dim, y_dim}));
164
- auto output_vec = d_out_mat.chip (i, 1 );
165
- x_scale_mat.device (place) =
166
- output_vec.reshape (Eigen::DSizes<int , 2 >(batch_size, 1 ))
167
- .broadcast (bcast_for_weight) *
168
- x_mat;
169
- blas.GEMM (CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1 ,
170
- x_scale.data <T>(), y->data <T>(), 0 , d_weight_i.data <T>());
171
- }
172
- }
173
-
174
- // Caculate the gradient of Input(Bias).
173
+ // Calculate the gradient of Input(Bias).
175
174
if (d_bias) {
176
175
d_bias->mutable_data <T>(ctx.GetPlace ());
177
176
auto d_bias_mat = framework::EigenVector<T>::Flatten (*d_bias);
0 commit comments