@@ -147,11 +147,14 @@ struct KronGradElemFunctor {
147
147
index_b += stride_b_[i] * pos_bi;
148
148
}
149
149
150
- size_t index_out_a = index_a * numel_b_ + index_b;
151
- size_t index_out_b = index_b * numel_a_ + index_a;
152
-
153
- dout_a_[index_out_a] = dout_[idx] * B_[index_b];
154
- dout_b_[index_out_b] = dout_[idx] * A_[index_a];
150
+ if (dout_a_) {
151
+ size_t index_out_a = index_a * numel_b_ + index_b;
152
+ dout_a_[index_out_a] = dout_[idx] * B_[index_b];
153
+ }
154
+ if (dout_b_) {
155
+ size_t index_out_b = index_b * numel_a_ + index_a;
156
+ dout_b_[index_out_b] = dout_[idx] * A_[index_a];
157
+ }
155
158
}
156
159
157
160
private:
@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
222
225
// dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
223
226
// dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
224
227
framework::Tensor dout_x;
225
- dout_x.mutable_data <T>({numel_x, numel_y}, dev_ctx.GetPlace ());
228
+ T* p_dout_x = nullptr ;
229
+ if (dx) {
230
+ dout_x.mutable_data <T>({numel_x, numel_y}, dev_ctx.GetPlace ());
231
+ p_dout_x = dout_x.data <T>();
232
+ }
226
233
framework::Tensor dout_y;
227
- dout_y.mutable_data <T>({numel_y, numel_x}, dev_ctx.GetPlace ());
234
+ T* p_dout_y = nullptr ;
235
+ if (dy) {
236
+ dout_y.mutable_data <T>({numel_y, numel_x}, dev_ctx.GetPlace ());
237
+ p_dout_y = dout_y.data <T>();
238
+ }
228
239
229
240
platform::ForRange<DeviceContext> for_range (dev_ctx, numel);
230
241
KronGradElemFunctor<T> func (dout.data <T>(), x.data <T>(), y.data <T>(),
231
- dout_x.data <T>(), dout_y.data <T>(),
232
- p_stride_dout, p_stride_x, p_stride_y,
233
- p_shape_y, numel_x, numel_y, ndims);
242
+ p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
243
+ p_stride_y, p_shape_y, numel_x, numel_y, ndims);
234
244
for_range (func);
235
245
236
246
// reduce_sum along axis 1
237
247
#if __NVCC__
238
248
auto stream = dev_ctx.stream (); // it is a cuda device_context
239
- TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
240
- dout_x, dx, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T>(),
241
- stream);
242
- TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
243
- dout_y, dy, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T>(),
244
- stream);
249
+ if (dx) {
250
+ TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
251
+ dout_x, dx, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T>(),
252
+ stream);
253
+ }
254
+ if (dy) {
255
+ TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
256
+ dout_y, dy, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T>(),
257
+ stream);
258
+ }
245
259
#else
246
- auto eigen_dout_x = framework::EigenMatrix<T>::Reshape (dout_x, 1 );
247
- auto eigen_dout_y = framework::EigenMatrix<T>::Reshape (dout_y, 1 );
248
- auto eigen_vec_dx = framework::EigenVector<T>::Flatten (*dx);
249
- auto eigen_vec_dy = framework::EigenVector<T>::Flatten (*dy);
250
260
auto * place = dev_ctx.eigen_device ();
251
261
Eigen::array<int , 1 > reduce_dim = {1 };
252
- eigen_vec_dx.device (*place) = eigen_dout_x.sum (reduce_dim);
253
- eigen_vec_dy.device (*place) = eigen_dout_y.sum (reduce_dim);
262
+ if (dx) {
263
+ auto eigen_dout_x = framework::EigenMatrix<T>::Reshape (dout_x, 1 );
264
+ auto eigen_vec_dx = framework::EigenVector<T>::Flatten (*dx);
265
+ eigen_vec_dx.device (*place) = eigen_dout_x.sum (reduce_dim);
266
+ }
267
+ if (dy) {
268
+ auto eigen_dout_y = framework::EigenMatrix<T>::Reshape (dout_y, 1 );
269
+ auto eigen_vec_dy = framework::EigenVector<T>::Flatten (*dy);
270
+ eigen_vec_dy.device (*place) = eigen_dout_y.sum (reduce_dim);
271
+ }
254
272
#endif
255
273
}
256
274
};
@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
307
325
308
326
auto * dx = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
309
327
auto * dy = ctx.Output <framework::Tensor>(framework::GradVarName (" Y" ));
310
- dx->mutable_data <T>(ctx.GetPlace ());
311
- dy->mutable_data <T>(ctx.GetPlace ());
328
+ if (dx) {
329
+ dx->mutable_data <T>(ctx.GetPlace ());
330
+ }
331
+ if (dy) {
332
+ dy->mutable_data <T>(ctx.GetPlace ());
333
+ }
312
334
313
335
int ndims = dout->dims ().size ();
314
336
framework::Tensor xx = UnsqueezeTo (*x, ndims);
315
- framework::Tensor dxx = UnsqueezeTo (*dx, ndims);
316
337
framework::Tensor yy = UnsqueezeTo (*y, ndims);
317
- framework::Tensor dyy = UnsqueezeTo (*dy, ndims);
338
+
339
+ framework::Tensor* pdxx = nullptr ;
340
+ framework::Tensor* pdyy = nullptr ;
341
+ framework::Tensor dxx;
342
+ framework::Tensor dyy;
343
+ if (dx) {
344
+ dxx = UnsqueezeTo (*dx, ndims);
345
+ pdxx = &dxx;
346
+ }
347
+
348
+ if (dy) {
349
+ dyy = UnsqueezeTo (*dy, ndims);
350
+ pdyy = &dyy;
351
+ }
318
352
319
353
KronGradOpFunctor<DeviceContext, T> func;
320
- func (dev_ctx, *dout, xx, yy, &dxx, &dyy );
354
+ func (dev_ctx, *dout, xx, yy, pdxx, pdyy );
321
355
}
322
356
};
323
357
0 commit comments