Skip to content

Commit 5a622c2

Browse files
author
Tomasz Patejko
committed
MKL elementwise add backward: Initial implementation with vector copy
1 parent 01fb2be commit 5a622c2

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

paddle/fluid/operators/elementwise_add_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL(
2525
REGISTER_OP_CPU_KERNEL(
2626
elementwise_add_grad,
2727
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
28-
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
29-
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
30-
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
28+
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>);
29+
// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
30+
// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,22 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
9898
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
9999
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
100100
int axis = ctx.Attr<int>("axis");
101-
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
102-
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
103-
IdentityGrad<T>());
101+
102+
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
103+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
104+
105+
if (dx)
106+
dx->mutable_data<T>(ctx.GetPlace());
107+
if (dy)
108+
dy->mutable_data<T>(ctx.GetPlace());
109+
110+
blas.VCOPY(dout->numel(), dout->data<T>(), dx->data<T>());
111+
blas.VCOPY(dout->numel(), dout->data<T>(), dy->data<T>());
112+
} else {
113+
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
114+
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
115+
IdentityGrad<T>());
116+
}
104117
}
105118
};
106119

0 commit comments

Comments
 (0)