Commit 0121231

Commit message: save

1 parent 2001004 commit 0121231

1 file changed: +78 -68 lines changed
src/ATen/native/xpu/sycl/LayerNormKernels.cpp
Lines changed: 78 additions & 68 deletions

@@ -631,79 +631,82 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<3> item) const {
     auto local_n = item.get_local_id(2); // [0, 32)
     auto local_m = item.get_local_id(1); // [0, 8)
-    auto tile_id = item.get_global_id(0); // tile id
-    auto tile_id_n = tile_id % num_tile_n_;
-    auto tile_id_m = tile_id / num_tile_n_;
-    auto tile_actual_row_base = tile_id_m * tile_size_m_;
-    auto tile_actual_col_base = tile_id_n * tile_size_n_;
-    auto actual_column = tile_actual_col_base + local_n;
-    if (actual_column < N_) {
-      // slm_row 0, 8, 16...56
-      for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
-           slm_row += num_subgroup_) {
-        accscalar_t sum_beta = accscalar_t(0);
-        accscalar_t sum_gamma = accscalar_t(0);
-        // row 0, 128, 256, ...896
-        auto row = tile_actual_row_base + slm_row * elements_per_thread_;
-        for (int i = 0; i < elements_per_thread_; i++) {
-          // row_local: row + 0, 8, 16, ...120
-          auto row_local = row + i * num_subgroup_;
-          auto actual_row = row_local + local_m;
-          // TODO: try tree reduction here if accuracy loss
-          if (actual_row < M_) {
-            if constexpr (have_beta) {
-              sum_beta += static_cast<accscalar_t>(
-                  dY_data_[actual_row * N_ + actual_column]);
-            }
-            if constexpr (have_gamma) {
-              sum_gamma += static_cast<accscalar_t>(
-                               dY_data_[actual_row * N_ + actual_column]) *
-                  (static_cast<accscalar_t>(
-                       X_data_[actual_row * N_ + actual_column]) -
-                   static_cast<accscalar_t>(mean_data_[actual_row])) *
-                  static_cast<accscalar_t>(var_data_[actual_row]);
+    for (auto tile_id = item.get_global_id(0);
+         tile_id < num_tile_n_ * num_tile_m_;
+         tile_id += item.get_group_range(0)) {
+      auto tile_id_n = tile_id % num_tile_n_;
+      auto tile_id_m = tile_id / num_tile_n_;
+      auto tile_actual_row_base = tile_id_m * tile_size_m_;
+      auto tile_actual_col_base = tile_id_n * tile_size_n_;
+      auto actual_column = tile_actual_col_base + local_n;
+      if (actual_column < N_) {
+        // slm_row 0, 8, 16...56
+        for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+             slm_row += num_subgroup_) {
+          accscalar_t sum_beta = accscalar_t(0);
+          accscalar_t sum_gamma = accscalar_t(0);
+          // row 0, 128, 256, ...896
+          auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+          for (int i = 0; i < elements_per_thread_; i++) {
+            // row_local: row + 0, 8, 16, ...120
+            auto row_local = row + i * num_subgroup_;
+            auto actual_row = row_local + local_m;
+            // TODO: try tree reduction here if accuracy loss
+            if (actual_row < M_) {
+              if constexpr (have_beta) {
+                sum_beta += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]);
+              }
+              if constexpr (have_gamma) {
+                sum_gamma += static_cast<accscalar_t>(
+                                 dY_data_[actual_row * N_ + actual_column]) *
+                    (static_cast<accscalar_t>(
+                         X_data_[actual_row * N_ + actual_column]) -
+                     static_cast<accscalar_t>(mean_data_[actual_row])) *
+                    static_cast<accscalar_t>(var_data_[actual_row]);
+              }
             }
           }
-        }

-        local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
-            sum_beta;
-        local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
-            sum_gamma;
-      }
+          local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_beta;
+          local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_gamma;
+        }

-      // item.barrier(sycl_local_fence);
-      accscalar_t slm_sum_beta = accscalar_t(0);
-      accscalar_t slm_sum_gamma = accscalar_t(0);
-      // slm row 64, 8 subgroup, i = 0,2,4,6
-      // slm row 32, 8 subgroup, i = 0,2
-      // slm row 16, 8 subgroup, i = 0
-      for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
-           i = i + 1) {
-        slm_sum_beta += local_sum_beta_
-            [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
-        slm_sum_gamma += local_sum_gamma_
-            [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
-      }
-      local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
-      local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
-    }
-    item.barrier(sycl_local_fence);
-    accscalar_t output_sum_beta = accscalar_t(0);
-    accscalar_t output_sum_gamma = accscalar_t(0);
-    if (local_m == 0 && actual_column < N_) {
-      for (int i = 0; i < num_subgroup_; i = i + 1) {
-        output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
-        output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
-      }
-      if constexpr (have_beta) {
-        db_data_[tile_id_m * N_ + actual_column] =
-            static_cast<weight_t>(output_sum_beta);
+        // item.barrier(sycl_local_fence);
+        accscalar_t slm_sum_beta = accscalar_t(0);
+        accscalar_t slm_sum_gamma = accscalar_t(0);
+        // slm row 64, 8 subgroup, i = 0,2,4,6
+        // slm row 32, 8 subgroup, i = 0,2
+        // slm row 16, 8 subgroup, i = 0
+        for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+             i = i + 1) {
+          slm_sum_beta += local_sum_beta_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+          slm_sum_gamma += local_sum_gamma_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        }
+        local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
+        local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
       }
+      item.barrier(sycl_local_fence);
+      accscalar_t output_sum_beta = accscalar_t(0);
+      accscalar_t output_sum_gamma = accscalar_t(0);
+      if (local_m == 0 && actual_column < N_) {
+        for (int i = 0; i < num_subgroup_; i = i + 1) {
+          output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
+          output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
+        }
+        if constexpr (have_beta) {
+          db_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_beta);
+        }

-      if constexpr (have_gamma) {
-        dg_data_[tile_id_m * N_ + actual_column] =
-            static_cast<weight_t>(output_sum_gamma);
+        if constexpr (have_gamma) {
+          dg_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_gamma);
+        }
       }
     }
   }
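
The hunk above changes the reduction from one tile per work-group to a work-group-strided tile loop: each work-group starts at its own tile index and advances by the total number of launched groups until the whole num_tile_m x num_tile_n grid is covered, which is what allows the launch (further down in this diff) to cap the number of work-groups. In this kernel the tile dimension has a local size of 1, so item.get_global_id(0) is effectively the group index. The following is a minimal standalone sketch of the same pattern, assuming plain SYCL 2020 with USM; the names (tile_visits, kNumTiles, kNumGroups, kLocalSize) and the sizes are illustrative and not taken from the commit.

#include <sycl/sycl.hpp>
#include <iostream>

int main() {
  constexpr size_t kNumTiles = 1000;  // logical tiles to cover
  constexpr size_t kNumGroups = 64;   // launched work-groups, fewer than tiles
  constexpr size_t kLocalSize = 32;   // work-items per work-group

  sycl::queue q;
  int* tile_visits = sycl::malloc_shared<int>(kNumTiles, q);
  for (size_t i = 0; i < kNumTiles; ++i)
    tile_visits[i] = 0;

  q.parallel_for(
       sycl::nd_range<1>{
           sycl::range<1>{kNumGroups * kLocalSize}, sycl::range<1>{kLocalSize}},
       [=](sycl::nd_item<1> item) {
         // Work-group-strided loop: group g handles tiles g, g + kNumGroups,
         // g + 2 * kNumGroups, ... so every tile is owned by exactly one group.
         for (size_t tile_id = item.get_group(0); tile_id < kNumTiles;
              tile_id += item.get_group_range(0)) {
           if (item.get_local_id(0) == 0) {
             tile_visits[tile_id] += 1;  // stand-in for the per-tile reduction
           }
         }
       })
      .wait();

  bool ok = true;
  for (size_t i = 0; i < kNumTiles; ++i)
    ok = ok && (tile_visits[i] == 1);
  std::cout << (ok ? "every tile processed once" : "coverage gap") << "\n";
  sycl::free(tile_visits, q);
  return 0;
}

Because the stride equals get_group_range(0), the tiles are partitioned disjointly across work-groups and no cross-group synchronization is needed; the barrier inside the real kernel stays local to each work-group.
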
@@ -724,6 +727,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       const scalar_t* X_data,
       weight_t* dg_block_data,
       weight_t* db_block_data,
+      int64_t num_tile_m,
       int64_t num_tile_n,
       int64_t tile_size_m,
       int64_t tile_size_n,
@@ -737,6 +741,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         X_data_(X_data),
         dg_data_(dg_block_data),
         db_data_(db_block_data),
+        num_tile_m_(num_tile_m),
         num_tile_n_(num_tile_n),
         tile_size_m_(tile_size_m),
         tile_size_n_(tile_size_n),
@@ -754,6 +759,7 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   const scalar_t* X_data_;
   weight_t* dg_data_;
   weight_t* db_data_;
+  int64_t num_tile_m_;
   int64_t num_tile_n_;
   int64_t tile_size_m_;
   int64_t tile_size_n_;
@@ -1113,7 +1119,8 @@ void _layer_norm_backward_kernel(
       dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
     }

-    size_t num_workgroup = num_tile_m * num_tile_n;
+    size_t num_workgroup =
+        std::min(num_tile_m * num_tile_n, static_cast<int>(thread_slots / local_size_x));
     if (dgamma.defined() && dbeta.defined()) {
       GammaBetaReduceFunctor<
           scalar_t,
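
On the launch side, the hunk above stops sizing the launch to the full tile grid and instead caps num_workgroup at roughly the number of work-groups the device can keep resident (thread_slots / local_size_x; thread_slots and local_size_x are defined elsewhere in this file and are not shown in the diff). The strided tile loop added to the kernel picks up whatever the capped launch does not cover in its first pass. The sketch below shows only the shape of that calculation, using standard SYCL device queries as a stand-in for the commit's thread_slots; all sizes are illustrative assumptions.

#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  sycl::queue q;
  sycl::device dev = q.get_device();

  // Illustrative tile grid for an (M, N) problem.
  int64_t num_tile_m = 512;
  int64_t num_tile_n = 128;
  size_t local_size_x = 8;  // work-items along the tile-row dimension

  // Rough residency estimate: compute units times the largest work-group each
  // can hold. The commit uses its own thread_slots value instead.
  size_t thread_slots =
      dev.get_info<sycl::info::device::max_compute_units>() *
      dev.get_info<sycl::info::device::max_work_group_size>();

  // Same shape as the diff: never launch more work-groups than there are
  // tiles, and never more than the device can keep resident at once.
  int64_t num_workgroup = std::min<int64_t>(
      num_tile_m * num_tile_n,
      static_cast<int64_t>(thread_slots / local_size_x));

  std::cout << "num_workgroup = " << num_workgroup << "\n";
  return 0;
}
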
@@ -1128,6 +1135,7 @@ void _layer_norm_backward_kernel(
           X_data,
           dgamma_blocks_ptr,
           dbeta_blocks_ptr,
+          num_tile_m,
           num_tile_n,
           tile_size_m,
           tile_size_n,
@@ -1169,6 +1177,7 @@ void _layer_norm_backward_kernel(
           X_data,
           dgamma_blocks_ptr,
           dbeta_blocks_ptr,
+          num_tile_m,
           num_tile_n,
           tile_size_m,
           tile_size_n,
@@ -1209,6 +1218,7 @@ void _layer_norm_backward_kernel(
           X_data,
           dgamma_blocks_ptr,
           dbeta_blocks_ptr,
+          num_tile_m,
           num_tile_n,
           tile_size_m,
           tile_size_n,
