@@ -20,9 +20,11 @@ limitations under the License. */
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
+ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif

#include "paddle/fluid/operators/math/math_function.h"
+ #include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {
@@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);

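+ // Gradient functor for the no-broadcast path: platform::ForRange invokes it
+ // once per element when x and y have identical dims. dx_/dy_ may be nullptr
+ // when the corresponding gradient output is not requested.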
+ template <typename T, typename DX_OP, typename DY_OP>
+ struct ElemwiseGradNoBroadcast {
+   const T* x_;
+   const T* y_;
+   const T* out_;
+   const T* dout_;
+
+   HOSTDEVICE void operator()(size_t i) {
+     if (dx_ != nullptr) {
+       dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
+     }
+     if (dy_ != nullptr) {
+       dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
+     }
+   }
+
+   DX_OP dx_op_;
+   DY_OP dy_op_;
+   T* dx_;
+   T* dy_;
+ };
+
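+ // Broadcast case 1 (post == 1): x is viewed as an (h, w) matrix and y as a
+ // length-w vector broadcast across the rows; dy[j] accumulates the sum of
+ // dy_op over all rows.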
+ template <typename T, typename DX_OP, typename DY_OP>
+ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
+                                       const T* dout, int h, int w, DX_OP dx_op,
+                                       DY_OP dy_op, T* dx, T* dy) {
+   for (int i = 0; i < h; ++i) {
+     for (int j = 0; j < w; ++j) {
+       int x_offset = i * w + j;
+       if (dx != nullptr) {
+         dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+       }
+       if (dy != nullptr) {
+         T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+         if (i == 0) {
+           dy[j] = tmp;
+         } else {
+           dy[j] += tmp;
+         }
+       }
+     }
+   }
+ }
+ #ifdef __NVCC__
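+ // CUDA kernel for broadcast case 1: one block per column j. Threads stride
+ // over the h rows in steps of ELEMWISE_MAX_BLOCK_DIM, writing dx directly and
+ // keeping a per-thread partial sum for dy in dynamic shared memory; thread 0
+ // then reduces the partial sums and writes dy[j].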
+ template <typename T, typename DX_OP, typename DY_OP>
+ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
+     const T* x, const T* y, const T* out, const T* dout, int h, int w,
+     DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+   extern __shared__ char shm_buffer[];
+   T* shm = reinterpret_cast<T*>(shm_buffer);
+
+   int j = blockIdx.x;
+   int i = threadIdx.x;
+   int tid = threadIdx.x;
+   shm[tid] = 0;
+
+   do {
+     int x_offset = i * w + j;
+     if (dx) {
+       dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+     }
+     if (dy) {
+       shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+     }
+     i += ELEMWISE_MAX_BLOCK_DIM;
+   } while (i < h);
+
+   if (dy) {
+     __syncthreads();
+
+     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+
+     // Sum, could be optimized
+     if (threadIdx.x == 0) {
+       for (int k = 1; k < h; ++k) {
+         shm[0] += shm[k];
+       }
+       dy[j] = shm[0];
+     }
+   }
+ }
+
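+ // Host-side launcher for the case-1 kernel: w blocks, up to
+ // ELEMWISE_MAX_BLOCK_DIM threads per block, and block_size * sizeof(T) bytes
+ // of dynamic shared memory for the dy reduction.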
+ template <typename T, typename DX_OP, typename DY_OP>
+ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
+                                        const T* y, const T* out, const T* dout,
+                                        int h, int w, DX_OP dx_op, DY_OP dy_op,
+                                        T* dx, T* dy) {
+   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
+   int grid_size = w;
+   int shared_mem_size = block_size * sizeof(T);
+   ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, shared_mem_size,
+                                      stream>>>(x, y, out, dout, h, w, dx_op,
+                                                dy_op, dx, dy);
+ }
+
+ #endif
+
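+ // Broadcast case 2 (post > 1): x is viewed as (pre, n, post) and y as a
+ // length-n vector broadcast over both the pre and post dimensions; dy[j]
+ // accumulates the sum of dy_op over all pre * post positions.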
+ template <typename T, typename DX_OP, typename DY_OP>
+ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
+                                       const T* dout, int pre, int n, int post,
+                                       DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+   for (int i = 0; i < pre; ++i) {
+     for (int j = 0; j < n; ++j) {
+       for (int k = 0; k < post; ++k) {
+         int x_offset = i * n * post + j * post + k;
+         if (dx != nullptr) {
+           dx[x_offset] =
+               dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+         }
+         if (dy != nullptr) {
+           T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+           if (i == 0 && k == 0) {
+             dy[j] = tmp;
+           } else {
+             dy[j] += tmp;
+           }
+         }
+       }
+     }
+   }
+ }
+
+ #ifdef __NVCC__
+
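+ // CUDA kernel for broadcast case 2: one block per j. Each thread walks the
+ // flattened (i, k) positions in strides of ELEMWISE_MAX_BLOCK_DIM, decoding
+ // i = ttid / post and k = ttid % post, and keeps a partial dy sum in shared
+ // memory; thread 0 reduces the partial sums and writes dy[j].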
+ template <typename T, typename DX_OP, typename DY_OP>
+ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
+     const T* x, const T* y, const T* out, const T* dout, int pre, int n,
+     int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
+   int tid = threadIdx.x;
+   int j = blockIdx.x;
+
+   extern __shared__ char shm_buffer[];
+   T* shm = reinterpret_cast<T*>(shm_buffer);
+   shm[tid] = 0;
+   int ttid = tid;
+
+   while (true) {
+     int i = ttid / post;
+     int k = ttid % post;
+     if (i >= pre) break;
+
+     int x_offset = i * n * post + j * post + k;
+
+     if (dx != nullptr) {
+       dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+     }
+
+     if (dy != nullptr) {
+       shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+     }
+
+     ttid += ELEMWISE_MAX_BLOCK_DIM;
+   }
+
+   if (dy) {
+     __syncthreads();
+     int h = pre * post;
+     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+
+     // Sum, could be optimized
+     if (tid == 0) {
+       for (int i = 1; i < h; ++i) {
+         shm[0] += shm[i];
+       }
+       dy[j] = shm[0];
+     }
+   }
+ }
+
+ template <typename T, typename DX_OP, typename DY_OP>
+ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
+                                        const T* y, const T* out, const T* dout,
+                                        int pre, int n, int post, DX_OP dx_op,
+                                        DY_OP dy_op, T* dx, T* dy) {
+   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
+   int grid_size = n;
+   int shared_mem_size = block_size * sizeof(T);
+   ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, shared_mem_size,
+                                      stream>>>(x, y, out, dout, pre, n, post,
+                                                dx_op, dy_op, dx, dy);
+ }
+
+ #endif
+
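+ // Dispatcher for elementwise gradients. dx_op/dy_op are callables with the
+ // signature T(T x, T y, T out, T dout) returning the per-element gradient
+ // contribution. As a rough, illustrative sketch (hypothetical functors, not
+ // part of this header), multiplication gradients could be written as:
+ //
+ //   template <typename T>
+ //   struct MulGradDX {
+ //     HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
+ //   };
+ //   template <typename T>
+ //   struct MulGradDY {
+ //     HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
+ //   };
+ //
+ // and passed in as, e.g., ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>,
+ // MulGradDY<T>>(ctx, x, y, out, dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>()).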
+ template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
+                          const framework::Tensor& x, const framework::Tensor& y,
+                          const framework::Tensor& out,
+                          const framework::Tensor& dout, int axis,
+                          framework::Tensor* dx, framework::Tensor* dy,
+                          DX_OP dx_op, DY_OP dy_op) {
+   if (x.dims() == y.dims()) {
+     size_t N = static_cast<size_t>(framework::product(x.dims()));
+     platform::ForRange<DeviceContext> for_range(
+         ctx.template device_context<DeviceContext>(), N);
+     for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
+         x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
+         dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+         dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
+   } else {  // broadcast: y's dims match a contiguous run of x's dims at axis
+     auto x_dim = x.dims();
+     auto y_dim = y.dims();
+
+     if (y_dim.size() == 1 && y_dim[0] == 1) {
+       // y is a scalar
+       auto extended_dims = framework::vectorize(x_dim);
+       extended_dims.push_back(1);
+       x_dim = framework::make_ddim(extended_dims);
+     }
+
+     axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
+     int pre, n, post;
+     get_mid_dims(x_dim, y_dim, axis, pre, n, post);
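+     // post == 1 means y aligns with the trailing dimensions of x, so the
+     // (h, w) row-broadcast path applies; otherwise use the general
+     // (pre, n, post) path.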
+     if (post == 1) {
+       int h = pre;
+       int w = n;
+       if (platform::is_gpu_place(ctx.GetPlace())) {
+ #ifdef __NVCC__
+         ElemwiseGradBroadcast1CUDA(
+             ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+             y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
+             dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+             dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+ #endif
+       } else {
+         ElemwiseGradBroadcast1CPU(
+             x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w,
+             dx_op, dy_op,
+             dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+             dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+       }
+     } else {
+       if (platform::is_gpu_place(ctx.GetPlace())) {
+ #ifdef __NVCC__
+         ElemwiseGradBroadcast2CUDA(
+             ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+             y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
+             dy_op,
+             dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+             dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+ #endif
+       } else {
+         ElemwiseGradBroadcast2CPU(
+             x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+             post, dx_op, dy_op,
+             dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+             dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+       }
+     }
+   }
+ }
+
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,