Skip to content

Commit fde47aa

Browse files
author
Tomasz Patejko
committed
MKL elementwise add backward: grad inputs copied when they are not null
1 parent 996d12f commit fde47aa

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
102102
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
103103
auto blas = math::GetBlas<DeviceContext, T>(ctx);
104104

105-
if (dx)
106-
dx->mutable_data<T>(ctx.GetPlace());
107-
if (dy)
108-
dy->mutable_data<T>(ctx.GetPlace());
109-
110-
blas.VCOPY(dout->numel(), dout->data<T>(), dx->data<T>());
111-
blas.VCOPY(dout->numel(), dout->data<T>(), dy->data<T>());
105+
if (dx) {
106+
blas.VCOPY(dout->numel(), dout->data<T>(),
107+
dx->mutable_data<T>(ctx.GetPlace()));
108+
}
109+
110+
if (dy) {
111+
blas.VCOPY(dout->numel(), dout->data<T>(),
112+
dy->mutable_data<T>(ctx.GetPlace()));
113+
}
112114
} else {
113115
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
114116
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),

0 commit comments

Comments
 (0)