@@ -626,84 +626,96 @@ template <
     typename mean_t,
     typename weight_t>
 struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
-  void operator()(sycl::nd_item<1> item) const {
-    auto local_id = item.get_local_id(0); // [0, 255]
-    auto group_id = item.get_group(0); // tile id
-
-    auto tile_row_base = group_id * tile_;
-
-    auto col = local_id % N_; // [0, N-1]
-    auto lane = local_id / N_; // [0, 7]
-
-    // slm_row 0, 8, 16...56
-    for (auto slm_row = 0; slm_row < tile_ / elements_per_thread_;
-         slm_row += num_subgroup_) {
-      accscalar_t sum_beta[8] = {accscalar_t(0)};
-      accscalar_t sum_gamma[8] = {accscalar_t(0)};
-      // row: tile_row_base + 0, 128, 256, ...896
-      auto row = tile_row_base + slm_row * elements_per_thread_;
-      for (int i = 0; i < elements_per_thread_; i++) {
-        // row_local: row + 0, 8, 16, ...120
-        auto row_local = row + i * num_subgroup_;
-        auto actual_row = row_local + lane;
-        // TODO: tree reduction here for better acc
-        if (actual_row < M_ && db_data_ != nullptr) {
-          sum_beta[i / 2] +=
-              static_cast<accscalar_t>(dY_data_[actual_row * N_ + col]);
+  void operator()(sycl::nd_item<3> item) const {
+    auto local_n = item.get_local_id(2); // [0, 32)
+    auto local_m = item.get_local_id(1); // [0, 8)
+    auto tile_id = item.get_global_id(0); // tile id
+    auto tile_id_n = tile_id % num_tile_n_;
+    auto tile_id_m = tile_id / num_tile_n_;
+    auto tile_actual_row_base = tile_id_m * tile_size_m_;
+    auto tile_actual_col_base = tile_id_n * tile_size_n_;
+    auto actual_column = tile_actual_col_base + local_n;
+    if (actual_column < N_) {
+      // slm_row 0, 8, 16...56
+      for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+           slm_row += num_subgroup_) {
+        accscalar_t sum_beta[8] = {accscalar_t(0)};
+        accscalar_t sum_gamma[8] = {accscalar_t(0)};
+        // row 0, 128, 256, ...896
+        auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+        for (int i = 0; i < elements_per_thread_; i++) {
+          // row_local: row + 0, 8, 16, ...120
+          auto row_local = row + i * num_subgroup_;
+          auto actual_row = row_local + local_m;
+          // TODO: tree reduction here for better acc
+          if (actual_row < M_ && db_data_ != nullptr) {
+            sum_beta[i / 2] += static_cast<accscalar_t>(
+                dY_data_[actual_row * N_ + actual_column]);
+          }
+          if (actual_row < M_ && dg_data_ != nullptr) {
+            sum_gamma[i / 2] += static_cast<accscalar_t>(
+                dY_data_[actual_row * N_ + actual_column]) *
+                (static_cast<accscalar_t>(
+                     X_data_[actual_row * N_ + actual_column]) -
+                 static_cast<accscalar_t>(mean_data_[actual_row])) *
+                static_cast<accscalar_t>(var_data_[actual_row]);
+          }
         }
-        if (actual_row < M_ && dg_data_ != nullptr) {
-          sum_gamma[i / 2] +=
-              static_cast<accscalar_t>(dY_data_[actual_row * N_ + col]) *
-              (static_cast<accscalar_t>(X_data_[actual_row * N_ + col]) -
-               static_cast<accscalar_t>(mean_data_[actual_row])) *
-              static_cast<accscalar_t>(var_data_[actual_row]);
+        for (int i = 0; i < 4; i++) {
+          sum_beta[i] += sum_beta[i + 4];
+          sum_gamma[i] += sum_gamma[i + 4];
         }
-      }
-      for (int i = 0; i < 4; i++) {
-        sum_beta[i] += sum_beta[i + 4];
-        sum_gamma[i] += sum_gamma[i + 4];
-      }

-      local_sum_beta_[slm_row * N_ + local_id] =
-          (sum_beta[0] + sum_beta[1]) + (sum_beta[2] + sum_beta[3]);
-      local_sum_gamma_[slm_row * N_ + local_id] =
-          (sum_gamma[0] + sum_gamma[1]) + (sum_gamma[2] + sum_gamma[3]);
-    }
-    item.barrier(sycl_local_fence);
+        local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+            (sum_beta[0] + sum_beta[1]) + (sum_beta[2] + sum_beta[3]);
+        local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+            (sum_gamma[0] + sum_gamma[1]) + (sum_gamma[2] + sum_gamma[3]);
+      }

-    accscalar_t slm_sum_beta[4] = {accscalar_t(0)};
-    accscalar_t slm_sum_gamma[4] = {accscalar_t(0)};
-    for (int i = 0; i < tile_ / elements_per_thread_ / num_subgroup_;
-         i = i + 2) {
-      slm_sum_beta[i / 2] =
-          local_sum_beta_[(i * num_subgroup_ + lane) * N_ + col] +
-          local_sum_beta_[((i + 1) * num_subgroup_ + lane) * N_ + col];
-      slm_sum_gamma[i / 2] =
-          local_sum_gamma_[(i * num_subgroup_ + lane) * N_ + col] +
-          local_sum_gamma_[((i + 1) * num_subgroup_ + lane) * N_ + col];
+      // item.barrier(sycl_local_fence);
+      accscalar_t slm_sum_beta[4] = {accscalar_t(0)};
+      accscalar_t slm_sum_gamma[4] = {accscalar_t(0)};
+      // slm row 64, 8 subgroup, i = 0,2,4,6
+      // slm row 32, 8 subgroup, i = 0,2
+      // slm row 16, 8 subgroup, i = 0
+      for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+           i = i + 2) {
+        slm_sum_beta[i / 2] =
+            local_sum_beta_
+                [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
+            local_sum_beta_
+                [((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        slm_sum_gamma[i / 2] =
+            local_sum_gamma_
+                [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
+            local_sum_gamma_
+                [((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+      }
+      local_sum_beta_[local_m * tile_size_n_ + local_n] =
+          (slm_sum_beta[0] + slm_sum_beta[1]) +
+          (slm_sum_beta[2] + slm_sum_beta[3]);
+      local_sum_gamma_[local_m * tile_size_n_ + local_n] =
+          (slm_sum_gamma[0] + slm_sum_gamma[1]) +
+          (slm_sum_gamma[2] + slm_sum_gamma[3]);
     }
-    local_sum_beta_[lane * N_ + col] = (slm_sum_beta[0] + slm_sum_beta[1]) +
-        (slm_sum_beta[2] + slm_sum_beta[3]);
-    local_sum_gamma_[lane * N_ + col] = (slm_sum_gamma[0] + slm_sum_gamma[1]) +
-        (slm_sum_gamma[2] + slm_sum_gamma[3]);
     item.barrier(sycl_local_fence);
     accscalar_t output_sum_beta[4] = {accscalar_t(0)};
     accscalar_t output_sum_gamma[4] = {accscalar_t(0)};
-    if (local_id < N_) {
+    if (local_m == 0 && actual_column < N_) {
       for (int i = 0; i < num_subgroup_; i = i + 2) {
-        output_sum_beta[i / 2] =
-            local_sum_beta_[i * N_ + col] + local_sum_beta_[(i + 1) * N_ + col];
-        output_sum_gamma[i / 2] = local_sum_gamma_[i * N_ + col] +
-            local_sum_gamma_[(i + 1) * N_ + col];
+        output_sum_beta[i / 2] = local_sum_beta_[i * tile_size_n_ + local_n] +
+            local_sum_beta_[(i + 1) * tile_size_n_ + local_n];
+        output_sum_gamma[i / 2] = local_sum_gamma_[i * tile_size_n_ + local_n] +
+            local_sum_gamma_[(i + 1) * tile_size_n_ + local_n];
       }
       if (db_data_ != nullptr)
-        db_data_[group_id * N_ + col] =
+        db_data_[tile_id_m * tile_size_n_ + actual_column] =
             (static_cast<weight_t>(output_sum_beta[0]) +
              static_cast<weight_t>(output_sum_beta[1])) +
             (static_cast<weight_t>(output_sum_beta[2]) +
              static_cast<weight_t>(output_sum_beta[3]));
       if (dg_data_ != nullptr)
-        dg_data_[group_id * N_ + col] =
+        dg_data_[tile_id_m * tile_size_n_ + actual_column] =
             (static_cast<weight_t>(output_sum_gamma[0]) +
              static_cast<weight_t>(output_sum_gamma[1])) +
             (static_cast<weight_t>(output_sum_gamma[2]) +
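
The rewritten operator() above replaces the 1-D launch (one work-group per row tile, with each work-item deriving col and lane from N) with a 3-D launch: dimension 0 enumerates (tile_m, tile_n) tiles, dimension 1 is an 8-wide row lane, and dimension 2 is a 32-wide column lane. The host-side sketch below reproduces only that index arithmetic; the concrete sizes (N, tile shape, subgroup count) are illustrative assumptions, not values taken from this commit, and the boundary checks of the real kernel are omitted.

#include <cstdint>
#include <cstdio>

// Standalone sketch of the index mapping used by the 3-D operator() above.
// Variable names mirror the functor fields; sizes are illustrative only.
int main() {
  const int64_t N = 64;
  const int64_t tile_size_m = 1024, tile_size_n = 32;
  const int64_t num_tile_n = (N + tile_size_n - 1) / tile_size_n; // 2
  const int64_t elements_per_thread = 16, num_subgroup = 8;

  // One work-item: the group index is the flattened tile id, the local id is
  // (1, local_m, local_n) with local_m in [0, 8) and local_n in [0, 32).
  int64_t tile_id = 3, local_m = 5, local_n = 17;

  int64_t tile_id_n = tile_id % num_tile_n; // 1 -> columns [32, 64)
  int64_t tile_id_m = tile_id / num_tile_n; // 1 -> rows [1024, 2048)
  int64_t actual_column = tile_id_n * tile_size_n + local_n; // 49

  // The two nested loops of operator(): each pass over slm_row reads
  // elements_per_thread rows spaced num_subgroup apart, offset by local_m.
  int64_t rows_read = 0, first_row = -1, last_row = -1;
  for (int64_t slm_row = 0; slm_row < tile_size_m / elements_per_thread;
       slm_row += num_subgroup) {
    int64_t row = tile_id_m * tile_size_m + slm_row * elements_per_thread;
    for (int64_t i = 0; i < elements_per_thread; i++) {
      int64_t actual_row = row + i * num_subgroup + local_m;
      if (first_row < 0)
        first_row = actual_row;
      last_row = actual_row;
      rows_read++; // the kernel accumulates dY[actual_row, actual_column] here
    }
  }
  // Prints: column 49, 128 rows read, rows 1029 .. 2045
  std::printf(
      "column %lld, %lld rows read, rows %lld .. %lld\n",
      (long long)actual_column,
      (long long)rows_read,
      (long long)first_row,
      (long long)last_row);
  return 0;
}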
@@ -713,9 +725,11 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {

   void sycl_ker_config_convention(sycl::handler& cgh) {
     local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>(
-        sycl::range<1>(N_ * tile_ / elements_per_thread_), cgh);
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
     local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>(
-        sycl::range<1>(N_ * tile_ / elements_per_thread_), cgh);
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
   }

   GammaBetaReduceFunctor(
@@ -725,7 +739,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       const scalar_t* X_data,
       weight_t* dg_block_data,
       weight_t* db_block_data,
-      int64_t tile,
+      int64_t num_tile_m,
+      int64_t num_tile_n,
+      int64_t tile_size_m,
+      int64_t tile_size_n,
       int64_t elements_per_thread,
       int64_t num_subgroup,
       int64_t M,
@@ -736,7 +753,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         X_data_(X_data),
         dg_data_(dg_block_data),
         db_data_(db_block_data),
-        tile_(tile),
+        num_tile_m_(num_tile_m),
+        num_tile_n_(num_tile_n),
+        tile_size_m_(tile_size_m),
+        tile_size_n_(tile_size_n),
         elements_per_thread_(elements_per_thread),
         num_subgroup_(num_subgroup),
         M_(M),
@@ -751,7 +771,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   const scalar_t* X_data_;
   weight_t* dg_data_;
   weight_t* db_data_;
-  int64_t tile_;
+  int64_t num_tile_m_;
+  int64_t num_tile_n_;
+  int64_t tile_size_m_;
+  int64_t tile_size_n_;
   int64_t elements_per_thread_;
   int64_t num_subgroup_;
   int64_t M_;
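
With the default tile shape picked by the host code in the next hunk (tile_size_m = 1024, tile_size_n = 32, elements_per_thread = 16), each of the two local accumulators allocated in sycl_ker_config_convention holds 64 x 32 partial sums. The small worked check below assumes accscalar_t is float; it just confirms the "slm is 16KB, 64*32 float * 2" comment that appears further down.

#include <cstdio>

// Worked example of the SLM footprint reserved in sycl_ker_config_convention,
// assuming the default tile shape and accscalar_t == float.
int main() {
  const long tile_size_m = 1024, tile_size_n = 32, elements_per_thread = 16;
  const long elems_per_array =
      tile_size_n * tile_size_m / elements_per_thread; // 2048
  const long bytes_per_array = elems_per_array * sizeof(float); // 8192
  // Two arrays are allocated: local_sum_beta_ and local_sum_gamma_.
  std::printf("SLM per work-group: %ld bytes\n", 2 * bytes_per_array); // 16384
  return 0;
}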
@@ -1040,12 +1063,52 @@ void _layer_norm_backward_kernel(
1040
1063
norm, config, can_use_32bit_index);
1041
1064
}
1042
1065
}
1043
- if (N <= 32 && M > 64 * 1024 ) {
1044
- const int num_subgroup = 8 ;
1045
- const int workgroup_size = N * num_subgroup;
1046
- const int tile_size = 1024 ;
1066
+ auto config_w = NormConfig (M, N, 0 , sizeof (scalar_t ));
1067
+ auto norm_config_global_size =
1068
+ config_w.workgroup_num * config_w.block_row * config_w.workgroup_size ;
1069
+ int thread_slots = syclGpuEuCount () * syclGpuHWThreadsPerEU ();
1070
+ // use two stage col reduction if norm config occupancy < 50%
1071
+ // TODO: we can releax this restriction in future for better perf
1072
+ bool use_two_stage_col_reduction =
1073
+ (dY.dtype () == kFloat || dY.dtype () == kBFloat16 ||
1074
+ dY.dtype () == kHalf ) &&
1075
+ norm_config_global_size / syclMaxSubGroupSize () * 2 <= thread_slots;
1076
+ // cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
1077
+ // in the M dimension
1078
+ if (use_two_stage_col_reduction && M > 64 * 1024 &&
1079
+ N / 32 < syclGpuEuCount () / syclGpuEUCountPerSubslice () / 2 ) {
1080
+ const size_t local_size_x = 8 ;
1081
+ const size_t SIMD = 32 ;
1082
+ // workgroup size is 256
1083
+ // slm is 16KB, 64*32 float * 2
1084
+ // elements_per_thread is at least 16
1047
1085
const int elements_per_thread = 16 ;
1048
- const int num_tile = (M + tile_size - 1 ) / tile_size;
1086
+ int tile_size_m = 1024 ;
1087
+ int tile_size_n = N < 32 ? N : 32 ;
1088
+ int num_tile_m = (M + tile_size_m - 1 ) / tile_size_m;
1089
+ int num_tile_n = (N + tile_size_n - 1 ) / tile_size_n;
1090
+ bool adjust_m = true ;
1091
+ // for M = 64*1024, N = 1, we choose tile size (256, 16) on pvc
1092
+ // TODO: we can tune these conditions in future
1093
+ for (auto i = 0 ; i < 3 ; i++) {
1094
+ // occupancy <= 50%
1095
+ if (num_tile_m * num_tile_n * local_size_x * SIMD /
1096
+ syclMaxSubGroupSize () * 2 <=
1097
+ thread_slots) {
1098
+ if (adjust_m) {
1099
+ tile_size_m /= 2 ;
1100
+ num_tile_m = (M + tile_size_m - 1 ) / tile_size_m;
1101
+ adjust_m = false ;
1102
+ } else {
1103
+ tile_size_n /= 2 ;
1104
+ num_tile_n = (N + tile_size_n - 1 ) / tile_size_n;
1105
+ adjust_m = true ;
1106
+ }
1107
+ } else {
1108
+ break ;
1109
+ }
1110
+ }
1111
+ // tile size can be (1024,32), (512,32), (512,16), (256, 16)
1049
1112
// Change these parameters will cause changes in kernel
1050
1113
const scalar_t * dY_data = dY.const_data_ptr <scalar_t >();
1051
1114
const scalar_t * X_data = X.const_data_ptr <scalar_t >();
@@ -1059,32 +1122,37 @@ void _layer_norm_backward_kernel(
1059
1122
if (dgamma.defined ()) {
1060
1123
auto options = dgamma.options ();
1061
1124
// TODO: how to set dgamma_blocks dtype = float32?
1062
- dgamma_blocks = at::empty ({num_tile , N}, options);
1125
+ dgamma_blocks = at::empty ({num_tile_m , N}, options);
1063
1126
dgamma_blocks_ptr = dgamma_blocks.data_ptr <weight_t >();
1064
1127
}
1065
1128
if (dbeta.defined ()) {
1066
1129
auto options = dbeta.options ();
1067
- dbeta_blocks = at::empty ({num_tile , N}, options);
1130
+ dbeta_blocks = at::empty ({num_tile_m , N}, options);
1068
1131
dbeta_blocks_ptr = dbeta_blocks.data_ptr <weight_t >();
1069
1132
}
1070
1133
1071
- int num_workgroup = (M + tile_size - 1 ) / tile_size ;
1134
+ size_t num_workgroup = num_tile_m * num_tile_n ;
1072
1135
GammaBetaReduceFunctor<scalar_t , accscalar_t , mean_t , weight_t > kfn (
1073
1136
mean_data,
1074
1137
var_data,
1075
1138
dY_data,
1076
1139
X_data,
1077
1140
dgamma_blocks_ptr,
1078
1141
dbeta_blocks_ptr,
1079
- tile_size,
1142
+ num_tile_m,
1143
+ num_tile_n,
1144
+ tile_size_m,
1145
+ tile_size_n,
1080
1146
elements_per_thread,
1081
- num_subgroup ,
1147
+ local_size_x ,
1082
1148
M,
1083
1149
N);
1084
1150
1085
- sycl_kernel_submit (
1086
- workgroup_size * num_workgroup,
1087
- workgroup_size,
1151
+ sycl_kernel_submit<
1152
+ GammaBetaReduceFunctor<scalar_t , accscalar_t , mean_t , weight_t >,
1153
+ 3 >(
1154
+ {num_workgroup, local_size_x, static_cast <size_t >(N < SIMD ? N : SIMD)},
1155
+ {1 , local_size_x, static_cast <size_t >(N < SIMD ? N : SIMD)},
1088
1156
getCurrentSYCLQueue (),
1089
1157
kfn);
1090
1158
dgamma = dgamma_blocks.sum (0 );
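
The tile-size selection in the hunk above halves tile_size_m and tile_size_n alternately, at most three times, as long as the estimated number of in-flight subgroups stays at or below 50% of the device's hardware thread slots. The standalone sketch below traces that heuristic on the host; thread_slots, SIMD and the subgroup size are made-up device numbers rather than values queried from a real GPU, and the std::max guards are added here only so the sketch stays well-defined for very small N.

#include <algorithm>
#include <cstdio>

// Sketch of the occupancy-driven tile-size heuristic, with illustrative
// device numbers and example problem sizes.
int main() {
  const long M = 256 * 1024, N = 32;  // example problem size
  const long thread_slots = 8192;     // e.g. EU count * HW threads per EU
  const long local_size_x = 8, SIMD = 32, subgroup_size = 32;

  long tile_size_m = 1024;
  long tile_size_n = std::min<long>(N, 32);
  long num_tile_m = (M + tile_size_m - 1) / tile_size_m;
  long num_tile_n = (N + tile_size_n - 1) / tile_size_n;

  bool adjust_m = true;
  for (int i = 0; i < 3; i++) {
    // subgroups * 2 <= thread_slots means occupancy <= 50%, so shrink a
    // tile dimension to launch more work-groups; otherwise stop.
    long subgroups =
        num_tile_m * num_tile_n * local_size_x * SIMD / subgroup_size;
    if (subgroups * 2 > thread_slots)
      break;
    if (adjust_m) {
      tile_size_m = std::max<long>(tile_size_m / 2, 1);
      num_tile_m = (M + tile_size_m - 1) / tile_size_m;
    } else {
      tile_size_n = std::max<long>(tile_size_n / 2, 1);
      num_tile_n = (N + tile_size_n - 1) / tile_size_n;
    }
    adjust_m = !adjust_m;
  }
  // With these numbers the result is tile (512, 16), grid (512, 2).
  std::printf(
      "tile (%ld, %ld), grid (%ld, %ld) work-groups\n",
      tile_size_m, tile_size_n, num_tile_m, num_tile_n);
  return 0;
}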