@@ -23,6 +23,12 @@ void CastKernel(const Context& dev_ctx,
23
23
phi::DataType dtype,
24
24
phi::DenseTensor* out);
25
25
26
+ template <typename T, typename Context>
27
+ void TransposeKernel (const Context& dev_ctx,
28
+ const phi::DenseTensor& x,
29
+ const std::vector<int >& axis,
30
+ phi::DenseTensor* out);
31
+
26
32
template <typename T = int >
27
33
inline void UpdatePadding (std::vector<T>* paddings,
28
34
const bool global_pooling,
@@ -233,21 +239,21 @@ void Pool2dKernel(const Context& dev_ctx,
233
239
}
234
240
235
241
template <typename T, typename Context>
236
- void Pool2dGradKernel (const Context& dev_ctx,
237
- const phi::DenseTensor& in_x,
238
- const phi::DenseTensor& out,
239
- const phi::DenseTensor& out_grad,
240
- const phi::IntArray& kernel_size,
241
- const std::vector<int >& strides_t ,
242
- const std::vector<int >& paddings_t ,
243
- bool ceil_mode,
244
- bool exclusive,
245
- const std::string& data_format,
246
- const std::string& pooling_type,
247
- bool global_pooling,
248
- bool adaptive,
249
- const std::string& padding_algorithm,
250
- phi::DenseTensor* in_x_grad) {
242
+ void AclopPool2dGradKernel (const Context& dev_ctx,
243
+ const phi::DenseTensor& in_x,
244
+ const phi::DenseTensor& out,
245
+ const phi::DenseTensor& out_grad,
246
+ const phi::IntArray& kernel_size,
247
+ const std::vector<int >& strides_t ,
248
+ const std::vector<int >& paddings_t ,
249
+ bool ceil_mode,
250
+ bool exclusive,
251
+ const std::string& data_format,
252
+ const std::string& pooling_type,
253
+ bool global_pooling,
254
+ bool adaptive,
255
+ const std::string& padding_algorithm,
256
+ phi::DenseTensor* in_x_grad) {
251
257
dev_ctx.template Alloc <T>(in_x_grad);
252
258
253
259
std::vector<int > ksize (kernel_size.GetData ().begin (),
@@ -451,6 +457,200 @@ void Pool2dGradKernel(const Context& dev_ctx,
451
457
}
452
458
}
453
459
460
+ template <typename T, typename Context>
461
+ void Pool2dGradKernel (const Context& dev_ctx,
462
+ const phi::DenseTensor& in_x,
463
+ const phi::DenseTensor& out,
464
+ const phi::DenseTensor& out_grad,
465
+ const phi::IntArray& kernel_size,
466
+ const std::vector<int >& strides_t ,
467
+ const std::vector<int >& paddings_t ,
468
+ bool ceil_mode,
469
+ bool exclusive,
470
+ const std::string& data_format,
471
+ const std::string& pooling_type,
472
+ bool global_pooling,
473
+ bool adaptive,
474
+ const std::string& padding_algorithm,
475
+ phi::DenseTensor* in_x_grad) {
476
+ DO_COMPATIBILITY (
477
+ aclnnAvgPool2dBackward,
478
+ (custom_kernel::AclopPool2dGradKernel<T, Context>(dev_ctx,
479
+ in_x,
480
+ out,
481
+ out_grad,
482
+ kernel_size,
483
+ strides_t ,
484
+ paddings_t ,
485
+ ceil_mode,
486
+ exclusive,
487
+ data_format,
488
+ pooling_type,
489
+ global_pooling,
490
+ adaptive,
491
+ padding_algorithm,
492
+ in_x_grad)));
493
+ // aclnnAvgPool2dBackward do not support padding_algorithm = "SAME"
494
+ if (pooling_type == " max" || padding_algorithm == " SAME" ) {
495
+ return custom_kernel::AclopPool2dGradKernel<T, Context>(dev_ctx,
496
+ in_x,
497
+ out,
498
+ out_grad,
499
+ kernel_size,
500
+ strides_t ,
501
+ paddings_t ,
502
+ ceil_mode,
503
+ exclusive,
504
+ data_format,
505
+ pooling_type,
506
+ global_pooling,
507
+ adaptive,
508
+ padding_algorithm,
509
+ in_x_grad);
510
+ }
511
+
512
+ dev_ctx.template Alloc <T>(in_x_grad);
513
+ const bool channel_last = data_format == " NHWC" ;
514
+
515
+ std::vector<int > ksize (kernel_size.GetData ().begin (),
516
+ kernel_size.GetData ().end ());
517
+ auto strides = strides_t ;
518
+ auto paddings = paddings_t ;
519
+
520
+ // update paddings
521
+ auto in_x_dims = in_x.dims ();
522
+ auto out_dims = out.dims ();
523
+ phi::DDim data_dims;
524
+ phi::DDim out_data_dims;
525
+ std::vector<int64_t > ksize_vec = {static_cast <int64_t >(ksize[0 ]),
526
+ static_cast <int64_t >(ksize[1 ])};
527
+ std::vector<int64_t > strides_vec = {static_cast <int64_t >(strides[0 ]),
528
+ static_cast <int64_t >(strides[1 ])};
529
+
530
+ if (channel_last) {
531
+ data_dims = phi::slice_ddim (in_x_dims, 1 , in_x_dims.size () - 1 );
532
+ out_data_dims = phi::slice_ddim (out_dims, 1 , out_dims.size () - 1 );
533
+ } else {
534
+ data_dims = phi::slice_ddim (in_x_dims, 2 , in_x_dims.size ());
535
+ out_data_dims = phi::slice_ddim (out_dims, 2 , out_dims.size ());
536
+ }
537
+ if (data_dims[0 ] == 1 && data_dims[1 ] == 1 ) {
538
+ TensorCopy (dev_ctx, out_grad, false , in_x_grad);
539
+ return ;
540
+ }
541
+
542
+ UpdatePadding (&paddings,
543
+ global_pooling,
544
+ adaptive,
545
+ padding_algorithm,
546
+ data_dims,
547
+ strides,
548
+ ksize);
549
+
550
+ PADDLE_ENFORCE_LT (
551
+ std::max (paddings[0 ], paddings[1 ]),
552
+ ksize[0 ],
553
+ phi::errors::InvalidArgument (
554
+ " Paddings should be less than %d, but max(pads[0], pads[1]) is %d." ,
555
+ ksize[0 ],
556
+ std::max (paddings[0 ], paddings[1 ])));
557
+ PADDLE_ENFORCE_LT (
558
+ std::max (paddings[2 ], paddings[3 ]),
559
+ ksize[1 ],
560
+ phi::errors::InvalidArgument (
561
+ " Paddings should be less than %d, but max(pads[2], pads[3]) is %d." ,
562
+ ksize[1 ],
563
+ std::max (paddings[2 ], paddings[3 ])));
564
+
565
+ if (adaptive) {
566
+ strides_vec[0 ] = std::floor (data_dims[0 ] / out_data_dims[0 ]);
567
+ strides_vec[1 ] = std::floor (data_dims[1 ] / out_data_dims[1 ]);
568
+ ksize_vec[0 ] = data_dims[0 ] - ((out_data_dims[0 ] - 1 ) * strides_vec[0 ]);
569
+ ksize_vec[1 ] = data_dims[1 ] - ((out_data_dims[1 ] - 1 ) * strides_vec[1 ]);
570
+
571
+ for (auto & pad : paddings) {
572
+ pad = 0 ;
573
+ }
574
+ }
575
+ PADDLE_ENFORCE_LT (
576
+ std::max (strides[0 ], strides[1 ]),
577
+ 64 ,
578
+ phi::errors::InvalidArgument (" strides should be less than %d, but "
579
+ " max(strides[0], strides[1]) is %d." ,
580
+ 64 ,
581
+ std::max (strides[0 ], strides[1 ])));
582
+
583
+ bool count_include_pad = !exclusive;
584
+ int64_t divison_override = 0 ;
585
+ int8_t cube_math_type = 0 ;
586
+
587
+ std::vector<int64_t > paddings_new;
588
+ paddings_new = {static_cast <int64_t >(paddings[1 ]),
589
+ static_cast <int64_t >(paddings[2 ])};
590
+
591
+ phi::DenseTensor transformed_out_grad, transformed_in_x,
592
+ transformed_in_x_grad;
593
+ if (channel_last) {
594
+ std::vector<int > perm = {0 , 3 , 1 , 2 };
595
+ std::vector<int > out_grad_tensor_shape = {
596
+ out_grad.dims ()[0 ],
597
+ out_grad.dims ()[3 ],
598
+ out_grad.dims ()[1 ],
599
+ out_grad.dims ()[2 ],
600
+ };
601
+ transformed_out_grad.Resize (phi::make_ddim (out_grad_tensor_shape));
602
+ dev_ctx.template Alloc <T>(&transformed_out_grad);
603
+ custom_kernel::TransposeKernel<T, Context>(
604
+ dev_ctx, out_grad, perm, &transformed_out_grad);
605
+
606
+ std::vector<int > in_x_tensor_shape = {
607
+ in_x.dims ()[0 ],
608
+ in_x.dims ()[3 ],
609
+ in_x.dims ()[1 ],
610
+ in_x.dims ()[2 ],
611
+ };
612
+ transformed_in_x.Resize (phi::make_ddim (in_x_tensor_shape));
613
+ dev_ctx.template Alloc <T>(&transformed_in_x);
614
+ custom_kernel::TransposeKernel<T, Context>(
615
+ dev_ctx, in_x, perm, &transformed_in_x);
616
+
617
+ std::vector<int > in_x_grad_tensor_shape = {
618
+ in_x_grad->dims ()[0 ],
619
+ in_x_grad->dims ()[3 ],
620
+ in_x_grad->dims ()[1 ],
621
+ in_x_grad->dims ()[2 ],
622
+ };
623
+ transformed_in_x_grad.Resize (phi::make_ddim (in_x_grad_tensor_shape));
624
+ dev_ctx.template Alloc <T>(&transformed_in_x_grad);
625
+ } else {
626
+ transformed_out_grad = out_grad;
627
+ transformed_in_x = in_x;
628
+ transformed_in_x_grad = *in_x_grad;
629
+ }
630
+ if (pooling_type == " avg" ) {
631
+ EXEC_NPU_CMD (aclnnAvgPool2dBackward,
632
+ dev_ctx,
633
+ transformed_out_grad,
634
+ transformed_in_x,
635
+ ksize_vec,
636
+ strides_vec,
637
+ paddings_new,
638
+ ceil_mode,
639
+ count_include_pad,
640
+ divison_override,
641
+ cube_math_type,
642
+ transformed_in_x_grad);
643
+ }
644
+
645
+ if (channel_last) {
646
+ std::vector<int > perm;
647
+ perm = {0 , 2 , 3 , 1 };
648
+ custom_kernel::TransposeKernel<T, Context>(
649
+ dev_ctx, transformed_in_x_grad, perm, in_x_grad);
650
+ } else {
651
+ in_x_grad = &transformed_in_x_grad;
652
+ }
653
+ }
454
654
} // namespace custom_kernel
455
655
456
656
PD_REGISTER_PLUGIN_KERNEL (pool2d,
0 commit comments