Skip to content

Commit 9241011

Browse files
author
Tomasz Patejko
committed
MKL elementwise add backward: backward works for integral types with fall back to default impl
1 parent fde47aa commit 9241011

File tree

2 files changed

+57
-18
lines changed

2 files changed

+57
-18
lines changed

paddle/fluid/operators/elementwise_add_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL(
2525
REGISTER_OP_CPU_KERNEL(
2626
elementwise_add_grad,
2727
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
28-
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>);
29-
// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
30-
// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
28+
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
29+
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
30+
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,57 @@ struct IdentityGrad {
8585
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
8686
};
8787

88+
template<typename DeviceContext, typename T>
89+
void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
90+
const framework::Tensor* x,
91+
const framework::Tensor* y,
92+
const framework::Tensor* out,
93+
const framework::Tensor* dout,
94+
framework::Tensor* dx,
95+
framework::Tensor* dy) {
96+
int axis = ctx.Attr<int>("axis");
97+
98+
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
99+
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
100+
IdentityGrad<T>());
101+
}
102+
103+
template<typename DeviceContext, typename T>
104+
typename std::enable_if<
105+
std::is_floating_point<T>::value &&
106+
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
107+
elementwise_add_grad(const framework::ExecutionContext& ctx,
108+
const framework::Tensor* x,
109+
const framework::Tensor* y,
110+
const framework::Tensor* out,
111+
const framework::Tensor* dout,
112+
framework::Tensor* dx, framework::Tensor* dy) {
113+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
114+
115+
if (dx) {
116+
blas.VCOPY(dout->numel(), dout->data<T>(),
117+
dx->mutable_data<T>(ctx.GetPlace()));
118+
}
119+
120+
if (dy) {
121+
blas.VCOPY(dout->numel(), dout->data<T>(),
122+
dy->mutable_data<T>(ctx.GetPlace()));
123+
}
124+
}
125+
126+
template<typename DeviceContext, typename T>
127+
typename std::enable_if<
128+
!std::is_floating_point<T>::value ||
129+
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
130+
elementwise_add_grad(const framework::ExecutionContext& ctx,
131+
const framework::Tensor* x,
132+
const framework::Tensor* y,
133+
const framework::Tensor* out,
134+
const framework::Tensor* dout,
135+
framework::Tensor* dx, framework::Tensor* dy) {
136+
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
137+
}
138+
88139
template <typename DeviceContext, typename T>
89140
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
90141
public:
@@ -97,24 +148,12 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
97148
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
98149
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
99150
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
100-
int axis = ctx.Attr<int>("axis");
101151

102152
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
103-
auto blas = math::GetBlas<DeviceContext, T>(ctx);
104-
105-
if (dx) {
106-
blas.VCOPY(dout->numel(), dout->data<T>(),
107-
dx->mutable_data<T>(ctx.GetPlace()));
108-
}
109-
110-
if (dy) {
111-
blas.VCOPY(dout->numel(), dout->data<T>(),
112-
dy->mutable_data<T>(ctx.GetPlace()));
113-
}
153+
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
114154
} else {
115-
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
116-
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
117-
IdentityGrad<T>());
155+
default_elementwise_add_grad<DeviceContext, T>(
156+
ctx, x, y, out, dout, dx, dy);
118157
}
119158
}
120159
};

0 commit comments

Comments
 (0)