
Commit bab1196

Merge pull request #10913 from tpatejko/tpatejko/optimized-elementwise-add
Blas optimized elementwise_add forward and backward passes
2 parents 4d29a5d + 3e876b3

3 files changed: 160 additions, 10 deletions

paddle/fluid/operators/elementwise_add_op.h
Lines changed: 102 additions & 10 deletions

@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -24,19 +26,57 @@ struct AddFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                        AddFunctor<T>(), z);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+  auto eigen_y = framework::EigenVector<T>::Flatten(*y);
+  auto eigen_z = framework::EigenVector<T>::Flatten(*z);
+
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
+    const auto x = ctx.Input<Tensor>("X");
+    const auto y = ctx.Input<Tensor>("Y");
+    auto z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          AddFunctor<T>(), z);
+
+    auto dims_equal = x->dims() == y->dims();
+    if (dims_equal) {
+      elementwise_add<DeviceContext, T>(ctx, x, y, z);
+    } else {
+      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+    }
   }
 };
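The forward path above picks its implementation at compile time: std::enable_if keeps the BLAS-backed elementwise_add overload viable only when T is a floating-point type and the device context is the CPU one; every other (device, type) pair falls through to default_elementwise_add. A minimal, self-contained sketch of that dispatch pattern follows; CPUContext, GPUContext, and the loop bodies are illustrative stand-ins for Paddle's device contexts and the blas.VADD call, not Paddle code.

// Sketch of the enable_if dispatch used in the diff above.
#include <iostream>
#include <type_traits>

struct CPUContext {};
struct GPUContext {};

// Viable only for floating-point T on the CPU context.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, CPUContext>::value>::type
add(const T* x, const T* y, T* z, int n) {
  std::cout << "fast path\n";
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];  // stands in for blas.VADD
}

// Viable for every other (DeviceContext, T) combination.
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, CPUContext>::value>::type
add(const T* x, const T* y, T* z, int n) {
  std::cout << "default path\n";
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

int main() {
  float xf[2] = {1, 2}, yf[2] = {3, 4}, zf[2];
  int xi[2] = {1, 2}, yi[2] = {3, 4}, zi[2];
  add<CPUContext>(xf, yf, zf, 2);  // fast path: float on CPU
  add<CPUContext>(xi, yi, zi, 2);  // default path: int is not floating point
  add<GPUContext>(xf, yf, zf, 2);  // default path: not the CPU context
}

Because the two conditions are mutually exclusive, exactly one overload survives substitution for any instantiation, so there is never an ambiguity and the choice costs nothing at run time.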

@@ -45,6 +85,55 @@ struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* x,
+                                  const framework::Tensor* y,
+                                  const framework::Tensor* out,
+                                  const framework::Tensor* dout,
+                                  framework::Tensor* dx,
+                                  framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+      IdentityGrad<T>());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x, const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout, framework::Tensor* dx,
+                     framework::Tensor* dy) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  if (dx) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dx->mutable_data<T>(ctx.GetPlace()));
+  }
+
+  if (dy) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dy->mutable_data<T>(ctx.GetPlace()));
+  }
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x, const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout, framework::Tensor* dx,
+                     framework::Tensor* dy) {
+  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  public:
@@ -57,10 +146,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-        ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-        IdentityGrad<T>());
+
+    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
+      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+    } else {
+      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
+                                                     dy);
+    }
   }
 };
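The fast backward path works because the gradient of z = x + y with respect to both inputs is the identity: dL/dx = dL/dy = dout. When the shapes match, the whole backward pass is therefore two buffer copies of dout, which is what the VCOPY calls above perform. A dependency-free sketch of that observation, with std::copy standing in for blas.VCOPY:

// For equal shapes, backprop through elementwise add is two copies.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> dout = {0.1, 0.2, 0.3};
  std::vector<double> dx(dout.size()), dy(dout.size());

  // dx = dout, dy = dout -- the entire backward pass for equal shapes.
  std::copy(dout.begin(), dout.end(), dx.begin());
  std::copy(dout.begin(), dout.end(), dy.begin());

  std::printf("dx[0]=%g dy[2]=%g\n", dx[0], dy[2]);  // dx[0]=0.1 dy[2]=0.3
}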

paddle/fluid/operators/math/blas.h
Lines changed: 16 additions & 0 deletions

@@ -125,6 +125,12 @@ class Blas {
   template <typename T>
   void AXPY(int n, T alpha, const T* x, T* y) const;
 
+  template <typename T>
+  void VADD(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+  void VCOPY(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
@@ -163,6 +169,16 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template AXPY<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VADD(ARGS... args) const {
+    Base()->template VADD<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VCOPY(ARGS... args) const {
+    Base()->template VCOPY<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);

paddle/fluid/operators/math/blas_impl.h
Lines changed: 42 additions & 0 deletions

@@ -34,6 +34,18 @@ struct CBlas<float> {
     cblas_saxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vsAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_scopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_sgemv(args...);
@@ -59,6 +71,18 @@ struct CBlas<double> {
     cblas_daxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vdAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_dcopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_dgemv(args...);
@@ -139,6 +163,24 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
   CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
+  CBlas<T>::VCOPY(n, x, 1, y, 1);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
+                                            T *z) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VADD(n, x, y, z);
+#else
+  this->template VCOPY<T>(n, y, z);
+  this->template AXPY<T>(n, 1., x, z);
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
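When PADDLE_WITH_MKLML is defined, VADD maps directly to MKL's vsAdd/vdAdd; otherwise it is composed from two level-1 BLAS calls: z = y (VCOPY), then z = 1.0 * x + z (AXPY), which equals elementwise x + y. A standalone check of that identity against a generic CBLAS; the link flag (-lopenblas, -lcblas, ...) depends on your BLAS installation:

// Verifies the non-MKL VADD fallback: copy then axpy equals x + y.
#include <cblas.h>
#include <cstdio>

int main() {
  const int n = 4;
  double x[4] = {1, 2, 3, 4};
  double y[4] = {10, 20, 30, 40};
  double z[4];

  cblas_dcopy(n, y, 1, z, 1);       // z = y            (the VCOPY step)
  cblas_daxpy(n, 1.0, x, 1, z, 1);  // z = 1.0 * x + z  (the AXPY step)

  for (int i = 0; i < n; ++i) std::printf("%g ", z[i]);  // 11 22 33 44
  std::printf("\n");
}

The fallback costs one extra pass over the data compared with a fused vdAdd, which is why the MKL path is preferred when it is available.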
