@@ -599,8 +599,8 @@ void _layer_norm_kernel(
       beta.defined() ? can_vectorize(beta_data, alignment) : true;

   if ((std::is_same_v<T, float> || std::is_same_v<T, at::Half> ||
-       std::is_same_v<T, at::BFloat16>)&&N <=
-          static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+       std::is_same_v<T, at::BFloat16>) &&
+      N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
       N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
       can_vec_beta) {
     launch_vectorized_layer_norm_kernel(
@@ -620,6 +620,157 @@ void _layer_norm_kernel(
   }
 }

+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename mean_t,
+    typename weight_t,
+    bool have_gamma = true,
+    bool have_beta = true>
+struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  void operator()(sycl::nd_item<3> item) const {
+    auto local_n = item.get_local_id(2); // [0, 32)
+    auto local_m = item.get_local_id(1); // [0, 8)
+    for (auto tile_id = item.get_global_id(0);
+         tile_id < num_tile_n_ * num_tile_m_;
+         tile_id += item.get_group_range(0)) {
+      auto tile_id_n = tile_id % num_tile_n_;
+      auto tile_id_m = tile_id / num_tile_n_;
+      auto tile_actual_row_base = tile_id_m * tile_size_m_;
+      auto tile_actual_col_base = tile_id_n * tile_size_n_;
+      auto actual_column = tile_actual_col_base + local_n;
+      if (actual_column < N_) {
+        // slm_row 0, 8, 16...56
+        for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+             slm_row += num_subgroup_) {
+          accscalar_t sum_beta = accscalar_t(0);
+          accscalar_t sum_gamma = accscalar_t(0);
+          // row 0, 128, 256, ...896
+          auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+          for (int i = 0; i < elements_per_thread_; i++) {
+            // row_local: row + 0, 8, 16, ...120
+            auto row_local = row + i * num_subgroup_;
+            auto actual_row = row_local + local_m;
+            // TODO: switch to a tree reduction here if this sequential sum
+            // causes accuracy loss
+            if (actual_row < M_) {
+              if constexpr (have_beta) {
+                sum_beta += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]);
+              }
+              if constexpr (have_gamma) {
+                sum_gamma += static_cast<accscalar_t>(
+                                 dY_data_[actual_row * N_ + actual_column]) *
+                    (static_cast<accscalar_t>(
+                         X_data_[actual_row * N_ + actual_column]) -
+                     static_cast<accscalar_t>(mean_data_[actual_row])) *
+                    static_cast<accscalar_t>(var_data_[actual_row]);
+              }
+            }
+          }
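+          // each lane of the work-group now holds a partial sum over
+          // elements_per_thread_ rows; park it in SLM for the cross-subgroup
+          // reduction below.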
+
+          local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_beta;
+          local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_gamma;
+        }
+
+        // item.barrier(sycl_local_fence);
+        accscalar_t slm_sum_beta = accscalar_t(0);
+        accscalar_t slm_sum_gamma = accscalar_t(0);
+        // slm row 64, 8 subgroup, i = 0,2,4,6
+        // slm row 32, 8 subgroup, i = 0,2
+        // slm row 16, 8 subgroup, i = 0
+        for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+             i = i + 1) {
+          slm_sum_beta += local_sum_beta_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+          slm_sum_gamma += local_sum_gamma_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        }
+        local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
+        local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
+      }
+      item.barrier(sycl_local_fence);
+      accscalar_t output_sum_beta = accscalar_t(0);
+      accscalar_t output_sum_gamma = accscalar_t(0);
+      if (local_m == 0 && actual_column < N_) {
+        for (int i = 0; i < num_subgroup_; i = i + 1) {
+          output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
+          output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
+        }
+        if constexpr (have_beta) {
+          db_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_beta);
+        }
+
+        if constexpr (have_gamma) {
+          dg_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_gamma);
+        }
+      }
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+    local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+  }
+
+  GammaBetaReduceFunctor(
+      const mean_t* mean_data,
+      const mean_t* var_data,
+      const scalar_t* dY_data,
+      const scalar_t* X_data,
+      weight_t* dg_block_data,
+      weight_t* db_block_data,
+      int64_t num_tile_m,
+      int64_t num_tile_n,
+      int64_t tile_size_m,
+      int64_t tile_size_n,
+      int64_t elements_per_thread,
+      int64_t num_subgroup,
+      int64_t M,
+      int64_t N)
+      : mean_data_(mean_data),
+        var_data_(var_data),
+        dY_data_(dY_data),
+        X_data_(X_data),
+        dg_data_(dg_block_data),
+        db_data_(db_block_data),
+        num_tile_m_(num_tile_m),
+        num_tile_n_(num_tile_n),
+        tile_size_m_(tile_size_m),
+        tile_size_n_(tile_size_n),
+        elements_per_thread_(elements_per_thread),
+        num_subgroup_(num_subgroup),
+        M_(M),
+        N_(N),
+        local_sum_beta_(),
+        local_sum_gamma_() {}
+
+ private:
+  const mean_t* mean_data_;
+  const mean_t* var_data_;
+  const scalar_t* dY_data_;
+  const scalar_t* X_data_;
+  weight_t* dg_data_;
+  weight_t* db_data_;
+  int64_t num_tile_m_;
+  int64_t num_tile_n_;
+  int64_t tile_size_m_;
+  int64_t tile_size_n_;
+  int64_t elements_per_thread_;
+  int64_t num_subgroup_;
+  int64_t M_;
+  int64_t N_;
+  sycl_local_acc_t<accscalar_t, 1> local_sum_beta_;
+  sycl_local_acc_t<accscalar_t, 1> local_sum_gamma_;
+};
+
 template <
     typename scalar_t,
     typename accscalar_t,
@@ -900,10 +1051,209 @@ void _layer_norm_backward_kernel(
         norm, config, can_use_32bit_index);
     }
   }
-
   auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
-  gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
-      dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  auto norm_config_global_size =
+      config_w.workgroup_num * config_w.block_row * config_w.workgroup_size;
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  // use a two-stage column reduction if the norm config occupancy is < 50%
+  // TODO: we can relax this restriction in the future for better perf
+  bool use_two_stage_col_reduction =
+      (dY.dtype() == kFloat || dY.dtype() == kBFloat16 ||
+       dY.dtype() == kHalf) &&
+      norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
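+  // norm_config_global_size / syclMaxSubGroupSize() is the number of
+  // subgroups (HW threads) the single-stage kernel would use; requiring twice
+  // that to fit in thread_slots keeps its occupancy at or below 50%.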
+  // cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
+  // in the M dimension
+  if (use_two_stage_col_reduction && M > 64 * 1024 &&
+      N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) {
+    const size_t local_size_x = 8;
+    const size_t SIMD = 32;
+    // workgroup size is 256
+    // slm is 16KB, 64*32 float * 2
+    // elements_per_thread is at least 16
+    const int elements_per_thread = 16;
+    int tile_size_m = 1024;
+    int tile_size_n = N < 32 ? N : 32;
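+    // SLM budget check: tile_size_m / elements_per_thread = 1024 / 16 = 64
+    // rows of partial sums, times tile_size_n = 32 columns, times two buffers
+    // (gamma and beta) of 4-byte float = 16 KB, matching the note above.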
+    int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+    int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+    bool adjust_m = true;
+    // for M = 64*1024, N = 1, we choose tile size (256, 16) on PVC
+    // TODO: consider tuning the tile size selection logic (tile_size_m,
+    // tile_size_n) and the occupancy calculation
+    for (auto i = 0; i < 3; i++) {
+      // occupancy <= 50%
+      if (num_tile_m * num_tile_n * local_size_x * SIMD /
+              syclMaxSubGroupSize() * 2 <=
+          thread_slots) {
+        if (adjust_m) {
+          tile_size_m /= 2;
+          num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+          adjust_m = false;
+        } else {
+          tile_size_n /= 2;
+          num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+          adjust_m = true;
+        }
+      } else {
+        break;
+      }
+    }
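+    // each halving doubles the tile count (and hence the number of
+    // work-groups), so this loop grows parallelism until the estimated
+    // occupancy would exceed 50% or three adjustments have been made.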
+    // tile size can be (1024,32), (512,32), (512,16), (256, 16)
+    // Modifying these parameters (num_subgroup, workgroup_size, tile_size,
+    // elements_per_thread) will alter the kernel configuration, potentially
+    // affecting performance and behavior.
+    const scalar_t* dY_data = dY.const_data_ptr<scalar_t>();
+    const scalar_t* X_data = X.const_data_ptr<scalar_t>();
+    weight_t* dg_data =
+        dgamma.defined() ? dgamma.data_ptr<weight_t>() : nullptr;
+    weight_t* db_data = dbeta.defined() ? dbeta.data_ptr<weight_t>() : nullptr;
+    Tensor dgamma_blocks;
+    Tensor dbeta_blocks;
+    weight_t* dgamma_blocks_ptr = nullptr;
+    weight_t* dbeta_blocks_ptr = nullptr;
+    if (dgamma.defined()) {
+      auto options = dgamma.options();
+      // TODO: how to set dgamma_blocks dtype = float32?
+      dgamma_blocks = at::empty({num_tile_m, N}, options);
+      dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>();
+    }
+    if (dbeta.defined()) {
+      auto options = dbeta.options();
+      dbeta_blocks = at::empty({num_tile_m, N}, options);
+      dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
+    }
+
+    size_t num_workgroup = std::min(
+        num_tile_m * num_tile_n,
+        static_cast<int>(thread_slots / local_size_x));
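+    // launch at most one work-group per tile, capped at
+    // thread_slots / local_size_x; tiles beyond the cap are covered by the
+    // grid-stride loop over group dimension 0 inside the functor.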
+    if (dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+      dbeta = dbeta_blocks.sum(0);
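+      // second stage: the kernel wrote one partial row per row-tile
+      // ({num_tile_m, N}); summing over dim 0 yields the final dgamma/dbeta.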
+    } else if (dgamma.defined() && !dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          false>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              false>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+    } else if (!dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          false,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              false,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dbeta = dbeta_blocks.sum(0);
+    } else {
+      return;
+    }
+
+  } else {
+    gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
+        dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  }
 }

 void layer_norm_kernel(