@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
110
110
auto c = g.slice (c_offsets, extents); // output candidate
111
111
112
112
// calculate final output
113
- h.device (place) = u * (h_p - c ) + c ;
113
+ h.device (place) = u * (c - h_p ) + h_p ;
114
114
}
115
115
};
116
116
@@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
146
146
auto * weight_grad =
147
147
context.Output <Tensor>(framework::GradVarName (" Weight" ));
148
148
auto * bias_grad = context.Output <Tensor>(framework::GradVarName (" Bias" ));
149
- input_grad->mutable_data <T>(context.GetPlace ());
150
- hidden_prev_grad->mutable_data <T>(context.GetPlace ());
151
- weight_grad->mutable_data <T>(context.GetPlace ());
152
149
Tensor gate_grad;
153
- gate_grad.mutable_data <T>(input->dims (), context.GetPlace ());
154
150
Tensor reset_hidden_prev_grad;
155
- reset_hidden_prev_grad.mutable_data <T>(reset_hidden_prev->dims (),
156
- context.GetPlace ());
157
-
158
- int batch_size = input->dims ()[0 ];
159
- int frame_size = hidden_prev->dims ()[1 ];
160
151
161
152
const T* hidden_prev_data = hidden_prev->data <T>();
162
- T* hidden_prev_grad_data = hidden_prev_grad->data <T>();
163
153
const T* weight_data = weight->data <T>();
164
- T* weight_grad_data = weight_grad-> data <T>();
165
- T* gate_grad_data = gate_grad.data <T>();
154
+ T* gate_grad_data =
155
+ gate_grad.mutable_data <T>(input-> dims (), context. GetPlace () );
166
156
const T* reset_hidden_prev_data = reset_hidden_prev->data <T>();
167
- T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data <T>();
157
+ T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data <T>(
158
+ reset_hidden_prev->dims (), context.GetPlace ());
168
159
169
160
auto h_p = EigenMatrix<T>::From (*hidden_prev);
170
161
auto g = EigenMatrix<T>::From (*gate);
171
162
auto d_h = EigenMatrix<T>::From (*hidden_grad);
172
- auto d_x = EigenMatrix<T>::From (*input_grad);
173
- auto d_h_p = EigenMatrix<T>::From (*hidden_prev_grad);
174
163
auto d_g = EigenMatrix<T>::From (gate_grad);
175
164
auto d_r_h_p = EigenMatrix<T>::From (reset_hidden_prev_grad);
176
165
auto place = context.GetEigenDevice <Place>();
177
166
167
+ int batch_size = input->dims ()[0 ];
168
+ int frame_size = hidden_prev->dims ()[1 ];
169
+
178
170
Eigen::array<int , 2 > extents ({{batch_size, frame_size}});
179
171
Eigen::array<int , 2 > u_offsets ({{0 , 0 }});
180
172
auto u = g.slice (u_offsets, extents); // update gate
@@ -185,38 +177,52 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
185
177
186
178
// backward for unactivated update gate
187
179
ActGradCompute (context.Attr <int >(" gate_activation" ), place, u, u,
188
- d_g.slice (u_offsets, extents), d_h * (h_p - c ));
180
+ d_g.slice (u_offsets, extents), d_h * (c - h_p ));
189
181
// backward for unactivated output candidate
190
182
ActGradCompute (context.Attr <int >(" activation" ), place, c, c,
191
- d_g.slice (c_offsets, extents), d_h * (u. constant ( T ( 1 )) - u) );
183
+ d_g.slice (c_offsets, extents), d_h * u );
192
184
// backward for reset_hidden_prev
193
185
math::gemm<Place, T>(context.device_context (), false , true , batch_size,
194
186
frame_size, frame_size, 1 ,
195
187
gate_grad_data + frame_size * 2 , frame_size * 3 ,
196
188
weight_data + frame_size * frame_size * 2 , frame_size,
197
189
0 , reset_hidden_prev_grad_data, frame_size);
198
- // backward for state_weight
199
- math::gemm<Place, T>(
200
- context.device_context (), true , false , frame_size, frame_size,
201
- batch_size, 1 , reset_hidden_prev_data, frame_size,
202
- gate_grad_data + frame_size * 2 , frame_size * 3 , 0 ,
203
- weight_grad_data + frame_size * frame_size * 2 , frame_size);
204
190
// backward for unactivated reset gate
205
191
ActGradCompute (context.Attr <int >(" gate_activation" ), place, r, r,
206
192
d_g.slice (r_offsets, extents), d_r_h_p * h_p);
207
- // backward for update_gate_weight and reset_gate_weight
208
- math::gemm<Place, T>(context.device_context (), true , false , frame_size,
209
- frame_size * 2 , batch_size, 1 , hidden_prev_data,
210
- frame_size, gate_grad_data, frame_size * 3 , 0 ,
211
- weight_grad_data, frame_size * 2 );
193
+ // backward for weight
194
+ if (weight_grad) {
195
+ T* weight_grad_data = weight_grad->mutable_data <T>(context.GetPlace ());
196
+ // backward for state_weight
197
+ math::gemm<Place, T>(
198
+ context.device_context (), true , false , frame_size, frame_size,
199
+ batch_size, 1 , reset_hidden_prev_data, frame_size,
200
+ gate_grad_data + frame_size * 2 , frame_size * 3 , 0 ,
201
+ weight_grad_data + frame_size * frame_size * 2 , frame_size);
202
+
203
+ // backward for update_gate_weight and reset_gate_weight
204
+ math::gemm<Place, T>(context.device_context (), true , false , frame_size,
205
+ frame_size * 2 , batch_size, 1 , hidden_prev_data,
206
+ frame_size, gate_grad_data, frame_size * 3 , 0 ,
207
+ weight_grad_data, frame_size * 2 );
208
+ }
212
209
// backward for hidden_prev
213
- d_h_p.device (place) = d_r_h_p * r + d_h * u;
214
- math::gemm<Place, T>(context.device_context (), false , true , batch_size,
215
- frame_size, frame_size * 2 , 1 , gate_grad_data,
216
- frame_size * 3 , weight_data, frame_size * 2 , 1 ,
217
- hidden_prev_grad_data, frame_size);
210
+ if (hidden_prev_grad) {
211
+ T* hidden_prev_grad_data =
212
+ hidden_prev_grad->mutable_data <T>(context.GetPlace ());
213
+ auto d_h_p = EigenMatrix<T>::From (*hidden_prev_grad);
214
+ d_h_p.device (place) = d_r_h_p * r + d_h * (u.constant (T (1 )) - u);
215
+ math::gemm<Place, T>(context.device_context (), false , true , batch_size,
216
+ frame_size, frame_size * 2 , 1 , gate_grad_data,
217
+ frame_size * 3 , weight_data, frame_size * 2 , 1 ,
218
+ hidden_prev_grad_data, frame_size);
219
+ }
218
220
// backward for input
219
- d_x.device (place) = d_g;
221
+ if (input_grad) {
222
+ input_grad->mutable_data <T>(context.GetPlace ());
223
+ auto d_x = EigenMatrix<T>::From (*input_grad);
224
+ d_x.device (place) = d_g;
225
+ }
220
226
// backward for bias
221
227
if (bias_grad) {
222
228
bias_grad->mutable_data <T>(context.GetPlace ());
0 commit comments