@@ -631,79 +631,82 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<3> item) const {
     auto local_n = item.get_local_id(2); // [0, 32)
     auto local_m = item.get_local_id(1); // [0, 8)
-    auto tile_id = item.get_global_id(0); // tile id
-    auto tile_id_n = tile_id % num_tile_n_;
-    auto tile_id_m = tile_id / num_tile_n_;
-    auto tile_actual_row_base = tile_id_m * tile_size_m_;
-    auto tile_actual_col_base = tile_id_n * tile_size_n_;
-    auto actual_column = tile_actual_col_base + local_n;
-    if (actual_column < N_) {
-      // slm_row 0, 8, 16...56
-      for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
-           slm_row += num_subgroup_) {
-        accscalar_t sum_beta = accscalar_t(0);
-        accscalar_t sum_gamma = accscalar_t(0);
-        // row 0, 128, 256, ...896
-        auto row = tile_actual_row_base + slm_row * elements_per_thread_;
-        for (int i = 0; i < elements_per_thread_; i++) {
-          // row_local: row + 0, 8, 16, ...120
-          auto row_local = row + i * num_subgroup_;
-          auto actual_row = row_local + local_m;
-          // TODO: try tree reduction here if accuracy loss
-          if (actual_row < M_) {
-            if constexpr (have_beta) {
-              sum_beta += static_cast<accscalar_t>(
-                  dY_data_[actual_row * N_ + actual_column]);
-            }
-            if constexpr (have_gamma) {
-              sum_gamma += static_cast<accscalar_t>(
-                  dY_data_[actual_row * N_ + actual_column]) *
-                  (static_cast<accscalar_t>(
-                       X_data_[actual_row * N_ + actual_column]) -
-                   static_cast<accscalar_t>(mean_data_[actual_row])) *
-                  static_cast<accscalar_t>(var_data_[actual_row]);
-            }
-          }
-        }
-
-        local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
-            sum_beta;
-        local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
-            sum_gamma;
-      }
-
-      // item.barrier(sycl_local_fence);
-      accscalar_t slm_sum_beta = accscalar_t(0);
-      accscalar_t slm_sum_gamma = accscalar_t(0);
-      // slm row 64, 8 subgroup, i = 0,2,4,6
-      // slm row 32, 8 subgroup, i = 0,2
-      // slm row 16, 8 subgroup, i = 0
-      for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
-           i = i + 1) {
-        slm_sum_beta += local_sum_beta_
-            [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
-        slm_sum_gamma += local_sum_gamma_
-            [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
-      }
-      local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
-      local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
-    }
-    item.barrier(sycl_local_fence);
-    accscalar_t output_sum_beta = accscalar_t(0);
-    accscalar_t output_sum_gamma = accscalar_t(0);
-    if (local_m == 0 && actual_column < N_) {
-      for (int i = 0; i < num_subgroup_; i = i + 1) {
-        output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
-        output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
-      }
-      if constexpr (have_beta) {
-        db_data_[tile_id_m * N_ + actual_column] =
-            static_cast<weight_t>(output_sum_beta);
-      }
-
-      if constexpr (have_gamma) {
-        dg_data_[tile_id_m * N_ + actual_column] =
-            static_cast<weight_t>(output_sum_gamma);
+    for (auto tile_id = item.get_global_id(0);
+         tile_id < num_tile_n_ * num_tile_m_;
+         tile_id += item.get_group_range(0)) {
+      auto tile_id_n = tile_id % num_tile_n_;
+      auto tile_id_m = tile_id / num_tile_n_;
+      auto tile_actual_row_base = tile_id_m * tile_size_m_;
+      auto tile_actual_col_base = tile_id_n * tile_size_n_;
+      auto actual_column = tile_actual_col_base + local_n;
+      if (actual_column < N_) {
+        // slm_row 0, 8, 16...56
+        for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+             slm_row += num_subgroup_) {
+          accscalar_t sum_beta = accscalar_t(0);
+          accscalar_t sum_gamma = accscalar_t(0);
+          // row 0, 128, 256, ...896
+          auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+          for (int i = 0; i < elements_per_thread_; i++) {
+            // row_local: row + 0, 8, 16, ...120
+            auto row_local = row + i * num_subgroup_;
+            auto actual_row = row_local + local_m;
+            // TODO: try tree reduction here if accuracy loss
+            if (actual_row < M_) {
+              if constexpr (have_beta) {
+                sum_beta += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]);
+              }
+              if constexpr (have_gamma) {
+                sum_gamma += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]) *
+                    (static_cast<accscalar_t>(
+                         X_data_[actual_row * N_ + actual_column]) -
+                     static_cast<accscalar_t>(mean_data_[actual_row])) *
+                    static_cast<accscalar_t>(var_data_[actual_row]);
+              }
+            }
+          }
+
+          local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_beta;
+          local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_gamma;
+        }
+
+        // item.barrier(sycl_local_fence);
+        accscalar_t slm_sum_beta = accscalar_t(0);
+        accscalar_t slm_sum_gamma = accscalar_t(0);
+        // slm row 64, 8 subgroup, i = 0,2,4,6
+        // slm row 32, 8 subgroup, i = 0,2
+        // slm row 16, 8 subgroup, i = 0
+        for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+             i = i + 1) {
+          slm_sum_beta += local_sum_beta_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+          slm_sum_gamma += local_sum_gamma_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        }
+        local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
+        local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
+      }
+      item.barrier(sycl_local_fence);
+      accscalar_t output_sum_beta = accscalar_t(0);
+      accscalar_t output_sum_gamma = accscalar_t(0);
+      if (local_m == 0 && actual_column < N_) {
+        for (int i = 0; i < num_subgroup_; i = i + 1) {
+          output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
+          output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
+        }
+        if constexpr (have_beta) {
+          db_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_beta);
+        }
+
+        if constexpr (have_gamma) {
+          dg_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_gamma);
+        }
       }
     }
   }
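The restructuring above is the standard grid-stride (persistent work-group) pattern: instead of mapping exactly one work-group to each (tile_m, tile_n) tile, every work-group starts at its own index and strides by `item.get_group_range(0)` until all `num_tile_m_ * num_tile_n_` tiles are covered. Dimension 0 of the nd-range evidently carries one work-item per group (the local ids live in dimensions 1 and 2), which is why `item.get_global_id(0)` can serve as the starting tile index. Below is a minimal, self-contained sketch of the same pattern reduced to a 1-D kernel; `num_tiles`, `num_groups`, and the per-tile work are illustrative stand-ins, not values from this patch.

```cpp
#include <sycl/sycl.hpp>

#include <iostream>
#include <vector>

int main() {
  constexpr size_t num_tiles = 1000; // logical tiles to cover
  constexpr size_t num_groups = 64;  // launched groups, deliberately fewer
  constexpr size_t local_size = 32;  // work-items per group

  sycl::queue q;
  std::vector<int> tile_visits(num_tiles, 0);
  {
    sycl::buffer<int, 1> buf(tile_visits.data(), sycl::range<1>(num_tiles));
    q.submit([&](sycl::handler& cgh) {
      sycl::accessor out(buf, cgh, sycl::write_only);
      cgh.parallel_for(
          sycl::nd_range<1>(num_groups * local_size, local_size),
          [=](sycl::nd_item<1> item) {
            // Grid-stride over tiles: each group begins at its own index and
            // advances by the total group count, so every tile is processed
            // even though num_groups < num_tiles.
            for (size_t tile = item.get_group(0); tile < num_tiles;
                 tile += item.get_group_range(0)) {
              if (item.get_local_id(0) == 0) {
                out[tile] = 1; // stand-in for the real per-tile reduction
              }
            }
          });
    });
  } // buffer destructor copies results back into tile_visits

  size_t covered = 0;
  for (int v : tile_visits)
    covered += v;
  std::cout << covered << " of " << num_tiles << " tiles covered\n";
}
```

Because every work-item of a group walks the same tile sequence, group-wide barriers such as the `item.barrier(sycl_local_fence)` call inside the patched loop remain valid: no work-item can skip an iteration its peers execute.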
@@ -724,6 +727,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       const scalar_t* X_data,
       weight_t* dg_block_data,
       weight_t* db_block_data,
+      int64_t num_tile_m,
       int64_t num_tile_n,
       int64_t tile_size_m,
       int64_t tile_size_n,
@@ -737,6 +741,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         X_data_(X_data),
         dg_data_(dg_block_data),
         db_data_(db_block_data),
+        num_tile_m_(num_tile_m),
         num_tile_n_(num_tile_n),
         tile_size_m_(tile_size_m),
         tile_size_n_(tile_size_n),
@@ -754,6 +759,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   const scalar_t* X_data_;
   weight_t* dg_data_;
   weight_t* db_data_;
+  int64_t num_tile_m_;
   int64_t num_tile_n_;
   int64_t tile_size_m_;
   int64_t tile_size_n_;
@@ -1113,7 +1119,8 @@ void _layer_norm_backward_kernel(
     dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
   }
 
-  size_t num_workgroup = num_tile_m * num_tile_n;
+  size_t num_workgroup = std::min(
+      num_tile_m * num_tile_n, static_cast<int64_t>(thread_slots / local_size_x));
   if (dgamma.defined() && dbeta.defined()) {
     GammaBetaReduceFunctor<
         scalar_t,
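This cap is what makes the grid-stride loop pay off: `num_workgroup` is now bounded by how many work-groups the device can keep resident (`thread_slots / local_size_x`) instead of growing with the tile count, and leftover tiles are absorbed by the striding loop. Note that both operands of `std::min` must have the same type, hence the `int64_t` cast. `thread_slots` and `local_size_x` are defined outside this hunk; the sketch below shows one plausible way such a bound could be derived from standard SYCL device queries, so treat the occupancy formula (compute units times max work-group size) as an assumption, not the patch's actual definition.

```cpp
#include <sycl/sycl.hpp>

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  sycl::queue q;
  auto dev = q.get_device();

  // Assumed occupancy bound: compute units x max work-group size. The real
  // thread_slots in the patch may come from a different, device-specific query.
  int64_t thread_slots =
      static_cast<int64_t>(
          dev.get_info<sycl::info::device::max_compute_units>()) *
      static_cast<int64_t>(
          dev.get_info<sycl::info::device::max_work_group_size>());

  int64_t local_size_x = 32 * 8; // 32 x 8 work-items per group, as in the kernel
  int64_t num_tile_m = 128, num_tile_n = 64; // example tile grid

  // Same clamping as the patch: never launch more groups than can be resident.
  size_t num_workgroup =
      std::min(num_tile_m * num_tile_n, thread_slots / local_size_x);
  std::cout << "workgroups: " << num_workgroup << "\n";
}
```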
@@ -1128,6 +1135,7 @@ void _layer_norm_backward_kernel(
         X_data,
         dgamma_blocks_ptr,
         dbeta_blocks_ptr,
+        num_tile_m,
         num_tile_n,
         tile_size_m,
         tile_size_n,
@@ -1169,6 +1177,7 @@ void _layer_norm_backward_kernel(
         X_data,
         dgamma_blocks_ptr,
         dbeta_blocks_ptr,
+        num_tile_m,
         num_tile_n,
         tile_size_m,
         tile_size_n,
@@ -1209,6 +1218,7 @@ void _layer_norm_backward_kernel(
         X_data,
         dgamma_blocks_ptr,
         dbeta_blocks_ptr,
+        num_tile_m,
         num_tile_n,
         tile_size_m,
         tile_size_n,