Commit 88c22e9

reyoung (Yang Yang(Tony)) authored and committed
Speed up elemwise grad (#8402)
* Speed up elemwise grad
* Fix bug
* Add macro for MAX_BLOCK_DIM
1 parent d316233 commit 88c22e9
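For context on the kernels added below: the broadcasted gradient path collapses X's shape around Y into a (pre, n, post) triple, where n is the number of elements in Y, pre is the product of X's dimensions in front of the axis Y is aligned to, and post is the product of the dimensions behind it. dX simply takes the per-element gradient of dOut, while the per-element contributions to dY are summed over the pre and post extents. The following is a minimal, self-contained sketch of that decomposition under hypothetical shapes x = [2, 3, 4, 5], y = [3, 4], axis = 1; it is illustrative only and mirrors what get_mid_dims is expected to compute, not code from this commit.

// Illustrative sketch only; shapes and variable names are hypothetical.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> x_dims = {2, 3, 4, 5};
  std::vector<int> y_dims = {3, 4};
  int axis = 1;

  int pre = 1, n = 1, post = 1;
  for (int i = 0; i < axis; ++i) pre *= x_dims[i];            // dims before y
  for (size_t i = 0; i < y_dims.size(); ++i) n *= y_dims[i];  // dims matching y
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    post *= x_dims[i];                                        // dims after y

  printf("pre=%d n=%d post=%d\n", pre, n, post);  // prints: pre=2 n=12 post=5
  assert(pre * n * post == 2 * 3 * 4 * 5);
  return 0;
}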

File tree

paddle/fluid/operators/elementwise_add_op.h
paddle/fluid/operators/elementwise_op_function.h

2 files changed: +259 -57 lines changed

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 5 additions & 57 deletions
@@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
 };

 template <typename T>
-struct ElementwiseAddGradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e;
-    }
-  }
-};
-
-template <typename T>
-struct ElementwiseAddBroadCastGradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ, typename Pre, typename N>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 2>(pre, n))
-                           .sum(Eigen::array<int, 1>{{0}});
-    }
-  }
-};
-
-template <typename T>
-struct ElementwiseAddBroadCast2GradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ, typename Pre, typename N, typename Post>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
-                  Post post) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 3>(pre, n, post))
-                           .sum(Eigen::array<int, 2>{{0, 2}});
-    }
-  }
+struct IdentityGrad {
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };

 template <typename DeviceContext, typename T>
@@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
-    ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
-                           ElementwiseAddBroadCastGradFunctor<T>,
-                           ElementwiseAddBroadCast2GradFunctor<T>>(
-        ctx, x, y, out, dout, axis, dx, dy);
+    ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+        ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+        IdentityGrad<T>());
   }
 };
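Note on the change above: the new path works through per-element grad functors with the signature T operator()(T x, T y, T out, T dout). For addition both partial derivatives are 1, so IdentityGrad returns dout for dX and dY alike, and any broadcast reduction is handled generically by ElemwiseGradCompute. A minimal CPU-only sketch of the same-shape (no broadcast) case follows; IdentityGradF and ElemwiseGradSameShape are hypothetical names used only for illustration, not APIs from this commit.

// Illustrative sketch only; names are hypothetical.
#include <cstddef>
#include <vector>

// Stand-in for the per-element grad functor: for z = x + y,
// dz/dx == dz/dy == 1, so both gradients are simply dout.
struct IdentityGradF {
  float operator()(float /*x*/, float /*y*/, float /*out*/, float dout) const {
    return dout;
  }
};

// Same-shape gradient: apply dx_op / dy_op element by element, which is what
// the ElemwiseGradNoBroadcast functor in the next file does via ForRange.
template <typename DX_OP, typename DY_OP>
void ElemwiseGradSameShape(const std::vector<float>& x,
                           const std::vector<float>& y,
                           const std::vector<float>& out,
                           const std::vector<float>& dout, DX_OP dx_op,
                           DY_OP dy_op, std::vector<float>* dx,
                           std::vector<float>* dy) {
  for (size_t i = 0; i < dout.size(); ++i) {
    if (dx) (*dx)[i] = dx_op(x[i], y[i], out[i], dout[i]);
    if (dy) (*dy)[i] = dy_op(x[i], y[i], out[i], dout[i]);
  }
}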

paddle/fluid/operators/elementwise_op_function.h

Lines changed: 254 additions & 0 deletions
@@ -20,9 +20,11 @@ limitations under the License. */
 
 #ifdef __NVCC__
 #include <thrust/iterator/iterator_adaptor.h>
+constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
 
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
@@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
 #define EIGEN_DIV(x, y) ((x) / (y))
 EIGEN_FUNCTOR(Div, EIGEN_DIV);
 
+template <typename T, typename DX_OP, typename DY_OP>
+struct ElemwiseGradNoBroadcast {
+  const T* x_;
+  const T* y_;
+  const T* out_;
+  const T* dout_;
+
+  HOSTDEVICE void operator()(size_t i) {
+    if (dx_ != nullptr) {
+      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
+    }
+    if (dy_ != nullptr) {
+      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
+    }
+  }
+
+  DX_OP dx_op_;
+  DY_OP dy_op_;
+  T* dx_;
+  T* dy_;
+};
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
+                                      const T* dout, int h, int w, DX_OP dx_op,
+                                      DY_OP dy_op, T* dx, T* dy) {
+  for (int i = 0; i < h; ++i) {
+    for (int j = 0; j < w; ++j) {
+      int x_offset = i * w + j;
+      if (dx != nullptr) {
+        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      }
+      if (dy != nullptr) {
+        T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+        if (i == 0) {
+          dy[j] = tmp;
+        } else {
+          dy[j] += tmp;
+        }
+      }
+    }
+  }
+}
+#ifdef __NVCC__
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void ElemwiseGradBroadcast1CUDAKernel(
+    const T* x, const T* y, const T* out, const T* dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+  extern __shared__ char shm_buffer[];
+  T* shm = reinterpret_cast<T*>(shm_buffer);
+
+  int j = blockIdx.x;
+  int i = threadIdx.x;
+  int tid = threadIdx.x;
+  shm[tid] = 0;
+
+  do {
+    int x_offset = i * w + j;
+    if (dx) {
+      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+    }
+    if (dy) {
+      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+    }
+    i += ELEMWISE_MAX_BLOCK_DIM;
+  } while (i < h);
+
+  if (dy) {
+    __syncthreads();
+
+    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+
+    // Sum, could be optimized
+    if (threadIdx.x == 0) {
+      for (int k = 1; k < h; ++k) {
+        shm[0] += shm[k];
+      }
+      dy[j] = shm[0];
+    }
+  }
+}
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
+                                       const T* y, const T* out, const T* dout,
+                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
+                                       T* dx, T* dy) {
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
+  int gird_size = w;
+  int shared_mem_size = block_size * sizeof(T);
+  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, shared_mem_size,
+                                     stream>>>(x, y, out, dout, h, w, dx_op,
+                                               dy_op, dx, dy);
+}
+
+#endif
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
+                                      const T* dout, int pre, int n, int post,
+                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+  for (int i = 0; i < pre; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int k = 0; k < post; ++k) {
+        int x_offset = i * n * post + j * post + k;
+        if (dx != nullptr) {
+          dx[x_offset] =
+              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+        }
+        if (dy != nullptr) {
+          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+          if (i == 0 && k == 0) {
+            dy[j] = tmp;
+          } else {
+            dy[j] += tmp;
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef __NVCC__
+
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void ElemwiseGradBroadcast2CUDAKernel(
+    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
+    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+  int tid = threadIdx.x;
+  int j = blockIdx.x;
+
+  extern __shared__ char shm_buffer[];
+  T* shm = reinterpret_cast<T*>(shm_buffer);
+  shm[tid] = 0;
+  int ttid = tid;
+
+  while (true) {
+    int i = ttid / post;
+    int k = ttid % post;
+    if (i >= pre) break;
+
+    int x_offset = i * n * post + j * post + k;
+
+    if (dx != nullptr) {
+      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+    }
+
+    if (dy != nullptr) {
+      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+    }
+
+    ttid += ELEMWISE_MAX_BLOCK_DIM;
+  }
+
+  if (dy) {
+    __syncthreads();
+    int h = pre * post;
+    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+
+    // Sum, could be optimized
+    if (tid == 0) {
+      for (int i = 1; i < h; ++i) {
+        shm[0] += shm[i];
+      }
+      dy[j] = shm[0];
+    }
+  }
+}
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
+                                       const T* y, const T* out, const T* dout,
+                                       int pre, int n, int post, DX_OP dx_op,
+                                       DY_OP dy_op, T* dx, T* dy) {
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
+  int gird_size = n;
+  int shared_mem_size = block_size * sizeof(T);
+  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, shared_mem_size,
+                                     stream>>>(x, y, out, dout, pre, n, post,
+                                               dx_op, dy_op, dx, dy);
+}
+
+#endif
+
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
+                         const framework::Tensor& x, const framework::Tensor& y,
+                         const framework::Tensor& out,
+                         const framework::Tensor& dout, int axis,
+                         framework::Tensor* dx, framework::Tensor* dy,
+                         DX_OP dx_op, DY_OP dy_op) {
+  if (x.dims() == y.dims()) {
+    size_t N = static_cast<size_t>(framework::product(x.dims()));
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(), N);
+    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
+        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
+        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
+  } else {  // Y is broadcast along some dimensions of X
+    auto x_dim = x.dims();
+    auto y_dim = y.dims();
+
+    if (y_dim.size() == 1 && y_dim[0] == 1) {
+      // y is a scalar
+      auto extended_dims = framework::vectorize(x_dim);
+      extended_dims.push_back(1);
+      x_dim = framework::make_ddim(extended_dims);
+    }
+
+    axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
+    int pre, n, post;
+    get_mid_dims(x_dim, y_dim, axis, pre, n, post);
+    if (post == 1) {
+      int h = pre;
+      int w = n;
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+        ElemwiseGradBroadcast1CUDA(
+            ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+            y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
+            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+      } else {
+        ElemwiseGradBroadcast1CPU(
+            x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w,
+            dx_op, dy_op,
+            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+      }
+    } else {
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+        ElemwiseGradBroadcast2CUDA(
+            ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+            y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
+            dy_op,
+            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+      } else {
+        ElemwiseGradBroadcast2CPU(
+            x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+            post, dx_op, dy_op,
+            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+      }
+    }
+  }
+};
+
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
 void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
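Usage note: the DX_OP / DY_OP template pair is what lets other elementwise operators reuse ElemwiseGradCompute. The sketch below is illustrative only and is not part of this commit; for z = x - y the gradients are dL/dx = dout and dL/dy = -dout, and SubGradDX / SubGradDY are hypothetical names.

// Illustrative sketch only (not from this commit). Assumes the same headers
// as elementwise_add_op.h, which provide the HOSTDEVICE macro.
template <typename T>
struct SubGradDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};

template <typename T>
struct SubGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};

// Inside a hypothetical ElementwiseSubGradKernel<DeviceContext, T>::Compute:
//   ElemwiseGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
//       ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
//       SubGradDY<T>());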
