11
11
12
12
#include < ATen/native/xpu/sycl/Atomics.h>
13
13
#include < ATen/native/xpu/sycl/BatchKernel.h>
14
+ #include < ATen/native/xpu/sycl/MemoryAccess.h>
14
15
#include < ATen/native/xpu/sycl/NumericLimits.h>
16
+
15
17
#include < comm/Runtime.h>
16
18
#include < comm/SYCLHelpers.h>
17
19
@@ -151,6 +153,119 @@ struct MaxPool2dKernelFunctor {
151
153
BatchKernelConfig cfg_;
152
154
};
153
155
156
// Vectorized channels-last (NHWC) 2-D max-pooling forward kernel.
// Each work-item handles `vec_size` adjacent channels of one output
// (batch, outputH, outputW) position, loading and storing them as a single
// vec_t so the contiguous channel dimension moves in one wide transaction.
// `stride_` is the per-batch output element count
// (numPlane_ * outputSizeH_ * outputSizeW_), so numBatch_ * stride_ /
// vec_size is the total number of vectors to produce.  Requires
// numPlane_ % vec_size == 0 (the dispatcher only picks such a vec_size).
template <typename scalar_t, typename vec_t, int vec_size>
struct MaxPool2dChannelLastVec {
  void operator()(sycl::nd_item<1> item) const {
    // Stride over all vectorized outputs by the total number of launched
    // work-items, so any launch size covers the whole tensor.
    for (auto outputIndex = item.get_global_linear_id();
         outputIndex < numBatch_ * stride_ / vec_size;
         outputIndex += item.get_local_range(0) * item.get_group_range(0)) {
      int batch = outputIndex / (stride_ / vec_size);
      int plane, outputH, outputW;
      int64_t load_offset, store_offset;
      // Decompose the flat vector index into NHWC coordinates; `plane` is
      // the channel-group (vector) index — the fastest-varying dimension.
      plane = outputIndex % (numPlane_ / vec_size);
      outputH =
          outputIndex / (numPlane_ / vec_size) / outputSizeW_ % outputSizeH_;
      outputW = outputIndex / (numPlane_ / vec_size) % outputSizeW_;
      store_offset = outputIndex;

      // Per-lane running max, seeded with the lowest representable value.
      vec_t maxVal_vec;
#pragma unroll
      for (int i = 0; i < vec_size; i++) {
        maxVal_vec[i] = at::numeric_limits<scalar_t>::lower_bound();
      }
      // Argmax within the flattened H*W input plane; -1 = nothing visited.
      int64_t maxIndex[vec_size];
      for (int i = 0; i < vec_size; i++) {
        maxIndex[i] = int64_t(-1);
      }
      // Pooling window [StartH, EndH) x [StartW, EndW) in input space,
      // clipped on the right by the input extents; negative (padded)
      // starts are advanced by whole dilation steps so only in-bounds
      // taps are read.
      int StartH = outputH * dH_ - padH_;
      int StartW = outputW * dW_ - padW_;
      int EndH = std::min(StartH + (kH_ - 1) * dilationH_ + 1, inputSizeH_);
      int EndW = std::min(StartW + (kW_ - 1) * dilationW_ + 1, inputSizeW_);
      while (StartH < 0)
        StartH += dilationH_;
      while (StartW < 0)
        StartW += dilationW_;
      for (int h = StartH; h < EndH; h += dilationH_) {
        for (int w = StartW; w < EndW; w += dilationW_) {
          // Vector offset of (batch, h, w, plane) in the NHWC input.
          // NOTE(review): the products are evaluated in `int` before the
          // assignment widens to int64_t — presumably the dispatcher
          // guarantees 32-bit-safe tensor sizes; confirm against callers.
          load_offset = batch * inputSizeH_ * inputSizeW_ * numPlane_ /
                  vec_size +
              plane + h * inputSizeW_ * numPlane_ / vec_size +
              w * numPlane_ / vec_size;
          vec_t val_vec = input_vec_[load_offset];
#pragma unroll
          for (int i = 0; i < vec_size; i++) {
            // A NaN input always wins, so NaN propagates to the output.
            if ((static_cast<scalar_t>(val_vec[i]) > maxVal_vec[i]) ||
                at::_isnan(val_vec[i])) {
              maxIndex[i] = h * inputSizeW_ + w;
              maxVal_vec[i] = static_cast<scalar_t>(val_vec[i]);
            }
          }
        }
      }
      // Indices are written per scalar lane; values as one vector store.
#pragma unroll
      for (int i = 0; i < vec_size; i++) {
        indices_[store_offset * vec_size + i] = maxIndex[i];
      }
      output_vec_[store_offset] = maxVal_vec;
    }
  }
  // All sizes/strides are element counts (not bytes).  `output_vec` and
  // `input_vec` are vec_t reinterpretations of the NHWC tensors;
  // `indices` stays a plain int64_t buffer.
  MaxPool2dChannelLastVec(
      vec_t* output_vec,
      int64_t* indices,
      const vec_t* input_vec,
      int numBatch,
      int numPlane,
      int inputSizeH,
      int inputSizeW,
      int outputSizeH,
      int outputSizeW,
      int kH,
      int kW,
      int dH,
      int dW,
      int padH,
      int padW,
      int dilationH,
      int dilationW,
      int stride)
      : output_vec_(output_vec),
        indices_(indices),
        input_vec_(input_vec),
        numBatch_(numBatch),
        numPlane_(numPlane),
        inputSizeH_(inputSizeH),
        inputSizeW_(inputSizeW),
        outputSizeH_(outputSizeH),
        outputSizeW_(outputSizeW),
        kH_(kH),
        kW_(kW),
        dH_(dH),
        dW_(dW),
        padH_(padH),
        padW_(padW),
        dilationH_(dilationH),
        dilationW_(dilationW),
        stride_(stride) {}

 private:
  vec_t* output_vec_;
  int64_t* indices_;
  const vec_t* input_vec_;
  int numBatch_;
  int numPlane_;
  int inputSizeH_;
  int inputSizeW_;
  int outputSizeH_;
  int outputSizeW_;
  int kH_;
  int kW_;
  int dH_;
  int dW_;
  int padH_;
  int padW_;
  int dilationH_;
  int dilationW_;
  int stride_;
};
268
+
154
269
template <typename scalar_t , bool is_channels_last>
155
270
struct MaxPool2dBackwardKernelFunctor {
156
271
void operator ()(sycl::nd_item<2 > item) const {
@@ -349,6 +464,56 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
349
464
BatchKernelConfig cfg_;
350
465
};
351
466
467
// Instantiates and launches MaxPool2dChannelLastVec for a compile-time
// `vec_size`: reinterprets the raw output/input pointers as
// memory::aligned_vector<scalar_t, vec_size> so the kernel does wide,
// channel-contiguous loads/stores, then submits num_wg work-groups of
// wg_size items on `queue`.  The braces scope the `using` alias and the
// locals so the macro can be expanded once per case of a switch.
#define LAUNCH_MAXPOOL_CHANNEL_LAST_VEC(                            \
    scalar_t,                                                       \
    vec_size,                                                       \
    num_wg,                                                         \
    wg_size,                                                        \
    queue,                                                          \
    output,                                                         \
    indices,                                                        \
    input,                                                          \
    numBatch,                                                       \
    numPlane,                                                       \
    inputSizeH,                                                     \
    inputSizeW,                                                     \
    outputSizeH,                                                    \
    outputSizeW,                                                    \
    kH,                                                             \
    kW,                                                             \
    dH,                                                             \
    dW,                                                             \
    padH,                                                           \
    padW,                                                           \
    dilationH,                                                      \
    dilationW,                                                      \
    stride)                                                         \
  {                                                                 \
    using vec_t = memory::aligned_vector<scalar_t, vec_size>;       \
    vec_t* output_vec = reinterpret_cast<vec_t*>(output);           \
    const vec_t* input_vec = reinterpret_cast<const vec_t*>(input); \
    auto kfn = MaxPool2dChannelLastVec<scalar_t, vec_t, vec_size>(  \
        output_vec,                                                 \
        indices,                                                    \
        input_vec,                                                  \
        numBatch,                                                   \
        numPlane,                                                   \
        inputSizeH,                                                 \
        inputSizeW,                                                 \
        outputSizeH,                                                \
        outputSizeW,                                                \
        kH,                                                         \
        kW,                                                         \
        dH,                                                         \
        dW,                                                         \
        padH,                                                       \
        padW,                                                       \
        dilationH,                                                  \
        dilationW,                                                  \
        stride);                                                    \
    sycl_kernel_submit(num_wg * wg_size, wg_size, queue, kfn);      \
  }
516
+
352
517
template <typename scalar_t , bool is_channels_last>
353
518
void launch_max_pool2d_kernel (
354
519
scalar_t * output,
@@ -368,11 +533,114 @@ void launch_max_pool2d_kernel(
368
533
int padW,
369
534
int dilationH,
370
535
int dilationW) {
371
- using KernelClass = MaxPool2dKernelFunctor<scalar_t , is_channels_last>;
372
-
373
536
auto & queue = at::xpu::getCurrentSYCLQueue ();
374
537
int outputSize = numBatch * numPlane * outputSizeH * outputSizeW;
375
538
int stride = numPlane * outputSizeH * outputSizeW;
539
+ int vec_size = 1 ;
540
+ int thread_slots = syclGpuEuCount () * syclGpuHWThreadsPerEU ();
541
+ int num_sub_wg;
542
+ auto wg_size = syclDeviceMaxWorkGroupSize ();
543
+ int64_t num_wg;
544
+ if constexpr (is_channels_last) {
545
+ for (vec_size =
546
+ std::min (8 , memory::can_vectorize_up_to<scalar_t >((char *)input));
547
+ vec_size >= 1 ;
548
+ vec_size /= 2 ) {
549
+ if (numPlane % vec_size != 0 ) {
550
+ continue ;
551
+ }
552
+ num_sub_wg = outputSize / vec_size / syclMaxSubGroupSize ();
553
+ if (2 * num_sub_wg > thread_slots) {
554
+ int total_thread = outputSize / vec_size;
555
+ num_wg = (total_thread + wg_size - 1 ) / wg_size;
556
+ break ;
557
+ }
558
+ }
559
+ switch (vec_size) {
560
+ case 8 :
561
+ LAUNCH_MAXPOOL_CHANNEL_LAST_VEC (
562
+ scalar_t ,
563
+ 8 ,
564
+ num_wg,
565
+ wg_size,
566
+ queue,
567
+ output,
568
+ indices,
569
+ input,
570
+ numBatch,
571
+ numPlane,
572
+ inputSizeH,
573
+ inputSizeW,
574
+ outputSizeH,
575
+ outputSizeW,
576
+ kH ,
577
+ kW ,
578
+ dH,
579
+ dW,
580
+ padH,
581
+ padW,
582
+ dilationH,
583
+ dilationW,
584
+ stride);
585
+ return ;
586
+ case 4 :
587
+ LAUNCH_MAXPOOL_CHANNEL_LAST_VEC (
588
+ scalar_t ,
589
+ 4 ,
590
+ num_wg,
591
+ wg_size,
592
+ queue,
593
+ output,
594
+ indices,
595
+ input,
596
+ numBatch,
597
+ numPlane,
598
+ inputSizeH,
599
+ inputSizeW,
600
+ outputSizeH,
601
+ outputSizeW,
602
+ kH ,
603
+ kW ,
604
+ dH,
605
+ dW,
606
+ padH,
607
+ padW,
608
+ dilationH,
609
+ dilationW,
610
+ stride);
611
+ return ;
612
+ case 2 :
613
+ LAUNCH_MAXPOOL_CHANNEL_LAST_VEC (
614
+ scalar_t ,
615
+ 2 ,
616
+ num_wg,
617
+ wg_size,
618
+ queue,
619
+ output,
620
+ indices,
621
+ input,
622
+ numBatch,
623
+ numPlane,
624
+ inputSizeH,
625
+ inputSizeW,
626
+ outputSizeH,
627
+ outputSizeW,
628
+ kH ,
629
+ kW ,
630
+ dH,
631
+ dW,
632
+ padH,
633
+ padW,
634
+ dilationH,
635
+ dilationW,
636
+ stride);
637
+ return ;
638
+ default :
639
+ break ;
640
+ };
641
+ }
642
+ using KernelClass = MaxPool2dKernelFunctor<scalar_t , is_channels_last>;
643
+
376
644
BatchKernelConfig cfg = BatchKernelConfig::make_config<KernelClass>(
377
645
1 , outputSize, 1 , 1 , true , BatchKernelConfig::Policy::pAdaptive);
378
646
auto kfn = KernelClass (
@@ -704,6 +972,6 @@ void max_pool2d_with_indices_backward_kernel(
704
972
}
705
973
706
974
} // namespace at::native::xpu
707
-
975
+ # undef LAUNCH_MAXPOOL_CHANNEL_LAST_VEC
708
976
#pragma GCC diagnostic pop
709
977
#pragma clang diagnostic pop
0 commit comments