@@ -624,7 +624,9 @@ template <
624
624
typename scalar_t ,
625
625
typename accscalar_t ,
626
626
typename mean_t ,
627
- typename weight_t >
627
+ typename weight_t ,
628
+ bool have_gamma = true ,
629
+ bool have_beta = true >
628
630
struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
629
631
void operator ()(sycl::nd_item<3 > item) const {
630
632
auto local_n = item.get_local_id (2 ); // [0, 32)
@@ -639,87 +641,70 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
639
641
// slm_row 0, 8, 16...56
640
642
for (auto slm_row = 0 ; slm_row < tile_size_m_ / elements_per_thread_;
641
643
slm_row += num_subgroup_) {
642
- accscalar_t sum_beta[ 8 ] = { accscalar_t (0 )} ;
643
- accscalar_t sum_gamma[ 8 ] = { accscalar_t (0 )} ;
644
+ accscalar_t sum_beta = accscalar_t (0 );
645
+ accscalar_t sum_gamma = accscalar_t (0 );
644
646
// row 0, 128, 256, ...896
645
647
auto row = tile_actual_row_base + slm_row * elements_per_thread_;
646
648
for (int i = 0 ; i < elements_per_thread_; i++) {
647
649
// row_local: row + 0, 8, 16, ...120
648
650
auto row_local = row + i * num_subgroup_;
649
651
auto actual_row = row_local + local_m;
650
- // TODO: tree reduction here for better acc
651
- if (actual_row < M_ && db_data_ != nullptr ) {
652
- sum_beta[i / 2 ] += static_cast <accscalar_t >(
653
- dY_data_[actual_row * N_ + actual_column]);
652
+ // TODO: try a tree reduction here if accuracy loss is observed
653
+ if (actual_row < M_) {
654
+ if constexpr (have_beta) {
655
+ sum_beta += static_cast <accscalar_t >(
656
+ dY_data_[actual_row * N_ + actual_column]);
657
+ }
658
+ if constexpr (have_gamma) {
659
+ sum_gamma += static_cast <accscalar_t >(
660
+ dY_data_[actual_row * N_ + actual_column]) *
661
+ (static_cast <accscalar_t >(
662
+ X_data_[actual_row * N_ + actual_column]) -
663
+ static_cast <accscalar_t >(mean_data_[actual_row])) *
664
+ static_cast <accscalar_t >(var_data_[actual_row]);
665
+ }
654
666
}
655
- if (actual_row < M_ && dg_data_ != nullptr ) {
656
- sum_gamma[i / 2 ] += static_cast <accscalar_t >(
657
- dY_data_[actual_row * N_ + actual_column]) *
658
- (static_cast <accscalar_t >(
659
- X_data_[actual_row * N_ + actual_column]) -
660
- static_cast <accscalar_t >(mean_data_[actual_row])) *
661
- static_cast <accscalar_t >(var_data_[actual_row]);
662
- }
663
- }
664
- for (int i = 0 ; i < 4 ; i++) {
665
- sum_beta[i] += sum_beta[i + 4 ];
666
- sum_gamma[i] += sum_gamma[i + 4 ];
667
667
}
668
668
669
669
local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
670
- ( sum_beta[ 0 ] + sum_beta[ 1 ]) + (sum_beta[ 2 ] + sum_beta[ 3 ]) ;
670
+ sum_beta;
671
671
local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
672
- ( sum_gamma[ 0 ] + sum_gamma[ 1 ]) + (sum_gamma[ 2 ] + sum_gamma[ 3 ]) ;
672
+ sum_gamma;
673
673
}
674
674
675
675
// item.barrier(sycl_local_fence);
676
- accscalar_t slm_sum_beta[ 4 ] = { accscalar_t (0 )} ;
677
- accscalar_t slm_sum_gamma[ 4 ] = { accscalar_t (0 )} ;
676
+ accscalar_t slm_sum_beta = accscalar_t (0 );
677
+ accscalar_t slm_sum_gamma = accscalar_t (0 );
678
678
// slm row 64, 8 subgroup, i = 0,1,...,7
679
679
// slm row 32, 8 subgroup, i = 0,1,2,3
680
680
// slm row 16, 8 subgroup, i = 0,1
681
681
for (int i = 0 ; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
682
- i = i + 2 ) {
683
- slm_sum_beta[i / 2 ] =
684
- local_sum_beta_
685
- [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
686
- local_sum_beta_
687
- [((i + 1 ) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
688
- slm_sum_gamma[i / 2 ] =
689
- local_sum_gamma_
690
- [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
691
- local_sum_gamma_
692
- [((i + 1 ) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
682
+ i = i + 1 ) {
683
+ slm_sum_beta += local_sum_beta_
684
+ [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
685
+ slm_sum_gamma += local_sum_gamma_
686
+ [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
693
687
}
694
- local_sum_beta_[local_m * tile_size_n_ + local_n] =
695
- (slm_sum_beta[0 ] + slm_sum_beta[1 ]) +
696
- (slm_sum_beta[2 ] + slm_sum_beta[3 ]);
697
- local_sum_gamma_[local_m * tile_size_n_ + local_n] =
698
- (slm_sum_gamma[0 ] + slm_sum_gamma[1 ]) +
699
- (slm_sum_gamma[2 ] + slm_sum_gamma[3 ]);
688
+ local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
689
+ local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
700
690
}
701
691
item.barrier (sycl_local_fence);
702
- accscalar_t output_sum_beta[ 4 ] = { accscalar_t (0 )} ;
703
- accscalar_t output_sum_gamma[ 4 ] = { accscalar_t (0 )} ;
692
+ accscalar_t output_sum_beta = accscalar_t (0 );
693
+ accscalar_t output_sum_gamma = accscalar_t (0 );
704
694
if (local_m == 0 && actual_column < N_) {
705
- for (int i = 0 ; i < num_subgroup_; i = i + 2 ) {
706
- output_sum_beta[i / 2 ] = local_sum_beta_[i * tile_size_n_ + local_n] +
707
- local_sum_beta_[(i + 1 ) * tile_size_n_ + local_n];
708
- output_sum_gamma[i / 2 ] = local_sum_gamma_[i * tile_size_n_ + local_n] +
709
- local_sum_gamma_[(i + 1 ) * tile_size_n_ + local_n];
695
+ for (int i = 0 ; i < num_subgroup_; i = i + 1 ) {
696
+ output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
697
+ output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
710
698
}
711
- if (db_data_ != nullptr )
699
+ if constexpr (have_beta) {
712
700
db_data_[tile_id_m * N_ + actual_column] =
713
- (static_cast <weight_t >(output_sum_beta[0 ]) +
714
- static_cast <weight_t >(output_sum_beta[1 ])) +
715
- (static_cast <weight_t >(output_sum_beta[2 ]) +
716
- static_cast <weight_t >(output_sum_beta[3 ]));
717
- if (dg_data_ != nullptr )
701
+ static_cast <weight_t >(output_sum_beta);
702
+ }
703
+
704
+ if constexpr (have_gamma) {
718
705
dg_data_[tile_id_m * N_ + actual_column] =
719
- (static_cast <weight_t >(output_sum_gamma[0 ]) +
720
- static_cast <weight_t >(output_sum_gamma[1 ])) +
721
- (static_cast <weight_t >(output_sum_gamma[2 ]) +
722
- static_cast <weight_t >(output_sum_gamma[3 ]));
706
+ static_cast <weight_t >(output_sum_gamma);
707
+ }
723
708
}
724
709
}
725
710
@@ -739,7 +724,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
739
724
const scalar_t * X_data,
740
725
weight_t * dg_block_data,
741
726
weight_t * db_block_data,
742
- int64_t num_tile_m,
743
727
int64_t num_tile_n,
744
728
int64_t tile_size_m,
745
729
int64_t tile_size_n,
@@ -753,7 +737,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
753
737
X_data_(X_data),
754
738
dg_data_(dg_block_data),
755
739
db_data_(db_block_data),
756
- num_tile_m_(num_tile_m),
757
740
num_tile_n_(num_tile_n),
758
741
tile_size_m_(tile_size_m),
759
742
tile_size_n_(tile_size_n),
@@ -771,7 +754,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
771
754
const scalar_t * X_data_;
772
755
weight_t * dg_data_;
773
756
weight_t * db_data_;
774
- int64_t num_tile_m_;
775
757
int64_t num_tile_n_;
776
758
int64_t tile_size_m_;
777
759
int64_t tile_size_n_;
@@ -1132,35 +1114,132 @@ void _layer_norm_backward_kernel(
1132
1114
}
1133
1115
1134
1116
size_t num_workgroup = num_tile_m * num_tile_n;
1135
- GammaBetaReduceFunctor<scalar_t , accscalar_t , mean_t , weight_t > kfn (
1136
- mean_data,
1137
- var_data,
1138
- dY_data,
1139
- X_data,
1140
- dgamma_blocks_ptr,
1141
- dbeta_blocks_ptr,
1142
- num_tile_m,
1143
- num_tile_n,
1144
- tile_size_m,
1145
- tile_size_n,
1146
- elements_per_thread,
1147
- local_size_x,
1148
- M,
1149
- N);
1150
-
1151
- sycl_kernel_submit<
1152
- GammaBetaReduceFunctor<scalar_t , accscalar_t , mean_t , weight_t >,
1153
- 3 >(
1154
- {num_workgroup, local_size_x, static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1155
- {1 , local_size_x, static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1156
- getCurrentSYCLQueue (),
1157
- kfn);
1158
- dgamma = dgamma_blocks.sum (0 );
1159
- dbeta = dbeta_blocks.sum (0 );
1117
+ if (dgamma.defined () && dbeta.defined ()) {
1118
+ GammaBetaReduceFunctor<
1119
+ scalar_t ,
1120
+ accscalar_t ,
1121
+ mean_t ,
1122
+ weight_t ,
1123
+ true ,
1124
+ true >
1125
+ kfn (mean_data,
1126
+ var_data,
1127
+ dY_data,
1128
+ X_data,
1129
+ dgamma_blocks_ptr,
1130
+ dbeta_blocks_ptr,
1131
+ num_tile_n,
1132
+ tile_size_m,
1133
+ tile_size_n,
1134
+ elements_per_thread,
1135
+ local_size_x,
1136
+ M,
1137
+ N);
1138
+
1139
+ sycl_kernel_submit<
1140
+ GammaBetaReduceFunctor<
1141
+ scalar_t ,
1142
+ accscalar_t ,
1143
+ mean_t ,
1144
+ weight_t ,
1145
+ true ,
1146
+ true >,
1147
+ 3 >(
1148
+ {num_workgroup,
1149
+ local_size_x,
1150
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1151
+ {1 ,
1152
+ local_size_x,
1153
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1154
+ getCurrentSYCLQueue (),
1155
+ kfn);
1156
+ dgamma = dgamma_blocks.sum (0 );
1157
+ dbeta = dbeta_blocks.sum (0 );
1158
+ } else if (dgamma.defined () && !dbeta.defined ()) {
1159
+ GammaBetaReduceFunctor<
1160
+ scalar_t ,
1161
+ accscalar_t ,
1162
+ mean_t ,
1163
+ weight_t ,
1164
+ true ,
1165
+ false >
1166
+ kfn (mean_data,
1167
+ var_data,
1168
+ dY_data,
1169
+ X_data,
1170
+ dgamma_blocks_ptr,
1171
+ dbeta_blocks_ptr,
1172
+ num_tile_n,
1173
+ tile_size_m,
1174
+ tile_size_n,
1175
+ elements_per_thread,
1176
+ local_size_x,
1177
+ M,
1178
+ N);
1179
+
1180
+ sycl_kernel_submit<
1181
+ GammaBetaReduceFunctor<
1182
+ scalar_t ,
1183
+ accscalar_t ,
1184
+ mean_t ,
1185
+ weight_t ,
1186
+ true ,
1187
+ false >,
1188
+ 3 >(
1189
+ {num_workgroup,
1190
+ local_size_x,
1191
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1192
+ {1 ,
1193
+ local_size_x,
1194
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1195
+ getCurrentSYCLQueue (),
1196
+ kfn);
1197
+ dgamma = dgamma_blocks.sum (0 );
1198
+ } else if (!dgamma.defined () && dbeta.defined ()) {
1199
+ GammaBetaReduceFunctor<
1200
+ scalar_t ,
1201
+ accscalar_t ,
1202
+ mean_t ,
1203
+ weight_t ,
1204
+ false ,
1205
+ true >
1206
+ kfn (mean_data,
1207
+ var_data,
1208
+ dY_data,
1209
+ X_data,
1210
+ dgamma_blocks_ptr,
1211
+ dbeta_blocks_ptr,
1212
+ num_tile_n,
1213
+ tile_size_m,
1214
+ tile_size_n,
1215
+ elements_per_thread,
1216
+ local_size_x,
1217
+ M,
1218
+ N);
1219
+
1220
+ sycl_kernel_submit<
1221
+ GammaBetaReduceFunctor<
1222
+ scalar_t ,
1223
+ accscalar_t ,
1224
+ mean_t ,
1225
+ weight_t ,
1226
+ false ,
1227
+ true >,
1228
+ 3 >(
1229
+ {num_workgroup,
1230
+ local_size_x,
1231
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1232
+ {1 ,
1233
+ local_size_x,
1234
+ static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1235
+ getCurrentSYCLQueue (),
1236
+ kfn);
1237
+ dbeta = dbeta_blocks.sum (0 );
1238
+ } else {
1239
+ return ;
1240
+ }
1160
1241
1161
1242
} else {
1162
- auto config_w = NormConfig (M, N, 0 , sizeof (scalar_t ));
1163
-
1164
1243
gamma_beta_bwd_simple_kernel<scalar_t , accscalar_t , mean_t , weight_t >(
1165
1244
dY, X, mean_data, var_data, dgamma, dbeta, config_w);
1166
1245
}
0 commit comments