
Commit 484ffa6

Commit message: save
1 parent c4930b6 commit 484ffa6

File tree

1 file changed: +148 -80 lines changed

src/ATen/native/xpu/sycl/LayerNormKernels.cpp

Lines changed: 148 additions & 80 deletions
@@ -626,84 +626,96 @@ template <
     typename mean_t,
     typename weight_t>
 struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
-  void operator()(sycl::nd_item<1> item) const {
-    auto local_id = item.get_local_id(0); // [0, 255]
-    auto group_id = item.get_group(0); // tile id
-
-    auto tile_row_base = group_id * tile_;
-
-    auto col = local_id % N_; // [0, N-1]
-    auto lane = local_id / N_; // [0, 7]
-
-    // slm_row 0, 8, 16...56
-    for (auto slm_row = 0; slm_row < tile_ / elements_per_thread_;
-         slm_row += num_subgroup_) {
-      accscalar_t sum_beta[8] = {accscalar_t(0)};
-      accscalar_t sum_gamma[8] = {accscalar_t(0)};
-      // row: tile_row_base + 0, 128, 256, ...896
-      auto row = tile_row_base + slm_row * elements_per_thread_;
-      for (int i = 0; i < elements_per_thread_; i++) {
-        // row_local: row + 0, 8, 16, ...120
-        auto row_local = row + i * num_subgroup_;
-        auto actual_row = row_local + lane;
-        // TODO: tree reduction here for better acc
-        if (actual_row < M_ && db_data_ != nullptr) {
-          sum_beta[i / 2] +=
-              static_cast<accscalar_t>(dY_data_[actual_row * N_ + col]);
+  void operator()(sycl::nd_item<3> item) const {
+    auto local_n = item.get_local_id(2); // [0, 32)
+    auto local_m = item.get_local_id(1); // [0, 8)
+    auto tile_id = item.get_global_id(0); // tile id
+    auto tile_id_n = tile_id % num_tile_n_;
+    auto tile_id_m = tile_id / num_tile_n_;
+    auto tile_actual_row_base = tile_id_m * tile_size_m_;
+    auto tile_actual_col_base = tile_id_n * tile_size_n_;
+    auto actual_column = tile_actual_col_base + local_n;
+    if (actual_column < N_) {
+      // slm_row 0, 8, 16...56
+      for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+           slm_row += num_subgroup_) {
+        accscalar_t sum_beta[8] = {accscalar_t(0)};
+        accscalar_t sum_gamma[8] = {accscalar_t(0)};
+        // row 0, 128, 256, ...896
+        auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+        for (int i = 0; i < elements_per_thread_; i++) {
+          // row_local: row + 0, 8, 16, ...120
+          auto row_local = row + i * num_subgroup_;
+          auto actual_row = row_local + local_m;
+          // TODO: tree reduction here for better acc
+          if (actual_row < M_ && db_data_ != nullptr) {
+            sum_beta[i / 2] += static_cast<accscalar_t>(
+                dY_data_[actual_row * N_ + actual_column]);
+          }
+          if (actual_row < M_ && dg_data_ != nullptr) {
+            sum_gamma[i / 2] += static_cast<accscalar_t>(
+                                    dY_data_[actual_row * N_ + actual_column]) *
+                (static_cast<accscalar_t>(
+                     X_data_[actual_row * N_ + actual_column]) -
+                 static_cast<accscalar_t>(mean_data_[actual_row])) *
+                static_cast<accscalar_t>(var_data_[actual_row]);
+          }
         }
-        if (actual_row < M_ && dg_data_ != nullptr) {
-          sum_gamma[i / 2] +=
-              static_cast<accscalar_t>(dY_data_[actual_row * N_ + col]) *
-              (static_cast<accscalar_t>(X_data_[actual_row * N_ + col]) -
-               static_cast<accscalar_t>(mean_data_[actual_row])) *
-              static_cast<accscalar_t>(var_data_[actual_row]);
+        for (int i = 0; i < 4; i++) {
+          sum_beta[i] += sum_beta[i + 4];
+          sum_gamma[i] += sum_gamma[i + 4];
         }
-      }
-      for (int i = 0; i < 4; i++) {
-        sum_beta[i] += sum_beta[i + 4];
-        sum_gamma[i] += sum_gamma[i + 4];
-      }
 
-      local_sum_beta_[slm_row * N_ + local_id] =
-          (sum_beta[0] + sum_beta[1]) + (sum_beta[2] + sum_beta[3]);
-      local_sum_gamma_[slm_row * N_ + local_id] =
-          (sum_gamma[0] + sum_gamma[1]) + (sum_gamma[2] + sum_gamma[3]);
-    }
-    item.barrier(sycl_local_fence);
+        local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+            (sum_beta[0] + sum_beta[1]) + (sum_beta[2] + sum_beta[3]);
+        local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+            (sum_gamma[0] + sum_gamma[1]) + (sum_gamma[2] + sum_gamma[3]);
+      }
 
-    accscalar_t slm_sum_beta[4] = {accscalar_t(0)};
-    accscalar_t slm_sum_gamma[4] = {accscalar_t(0)};
-    for (int i = 0; i < tile_ / elements_per_thread_ / num_subgroup_;
-         i = i + 2) {
-      slm_sum_beta[i / 2] =
-          local_sum_beta_[(i * num_subgroup_ + lane) * N_ + col] +
-          local_sum_beta_[((i + 1) * num_subgroup_ + lane) * N_ + col];
-      slm_sum_gamma[i / 2] =
-          local_sum_gamma_[(i * num_subgroup_ + lane) * N_ + col] +
-          local_sum_gamma_[((i + 1) * num_subgroup_ + lane) * N_ + col];
+      // item.barrier(sycl_local_fence);
+      accscalar_t slm_sum_beta[4] = {accscalar_t(0)};
+      accscalar_t slm_sum_gamma[4] = {accscalar_t(0)};
+      // slm row 64, 8 subgroup, i = 0,2,4,6
+      // slm row 32, 8 subgroup, i = 0,2
+      // slm row 16, 8 subgroup, i = 0
+      for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+           i = i + 2) {
+        slm_sum_beta[i / 2] =
+            local_sum_beta_
+                [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
+            local_sum_beta_
+                [((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        slm_sum_gamma[i / 2] =
+            local_sum_gamma_
+                [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
+            local_sum_gamma_
+                [((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+      }
+      local_sum_beta_[local_m * tile_size_n_ + local_n] =
+          (slm_sum_beta[0] + slm_sum_beta[1]) +
+          (slm_sum_beta[2] + slm_sum_beta[3]);
+      local_sum_gamma_[local_m * tile_size_n_ + local_n] =
+          (slm_sum_gamma[0] + slm_sum_gamma[1]) +
+          (slm_sum_gamma[2] + slm_sum_gamma[3]);
     }
-    local_sum_beta_[lane * N_ + col] = (slm_sum_beta[0] + slm_sum_beta[1]) +
-        (slm_sum_beta[2] + slm_sum_beta[3]);
-    local_sum_gamma_[lane * N_ + col] = (slm_sum_gamma[0] + slm_sum_gamma[1]) +
-        (slm_sum_gamma[2] + slm_sum_gamma[3]);
     item.barrier(sycl_local_fence);
     accscalar_t output_sum_beta[4] = {accscalar_t(0)};
     accscalar_t output_sum_gamma[4] = {accscalar_t(0)};
-    if (local_id < N_) {
+    if (local_m == 0 && actual_column < N_) {
       for (int i = 0; i < num_subgroup_; i = i + 2) {
-        output_sum_beta[i / 2] =
-            local_sum_beta_[i * N_ + col] + local_sum_beta_[(i + 1) * N_ + col];
-        output_sum_gamma[i / 2] = local_sum_gamma_[i * N_ + col] +
-            local_sum_gamma_[(i + 1) * N_ + col];
+        output_sum_beta[i / 2] = local_sum_beta_[i * tile_size_n_ + local_n] +
+            local_sum_beta_[(i + 1) * tile_size_n_ + local_n];
+        output_sum_gamma[i / 2] = local_sum_gamma_[i * tile_size_n_ + local_n] +
+            local_sum_gamma_[(i + 1) * tile_size_n_ + local_n];
       }
       if (db_data_ != nullptr)
-        db_data_[group_id * N_ + col] =
+        db_data_[tile_id_m * tile_size_n_ + actual_column] =
             (static_cast<weight_t>(output_sum_beta[0]) +
              static_cast<weight_t>(output_sum_beta[1])) +
             (static_cast<weight_t>(output_sum_beta[2]) +
             static_cast<weight_t>(output_sum_beta[3]));
       if (dg_data_ != nullptr)
-        dg_data_[group_id * N_ + col] =
+        dg_data_[tile_id_m * tile_size_n_ + actual_column] =
             (static_cast<weight_t>(output_sum_gamma[0]) +
              static_cast<weight_t>(output_sum_gamma[1])) +
             (static_cast<weight_t>(output_sum_gamma[2]) +
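Reading note: this hunk replaces the 1-D launch (one work-group per row tile, work-items flattened over N) with a 3-D launch that tiles both M and N. The index math can be read in isolation; the snippet below is a minimal standalone sketch using the same formulas as the kernel, and the concrete tile counts, tile sizes, and sample ids are assumptions for illustration, not values from the commit.

#include <cstdint>
#include <iostream>

// Illustrative only: mirrors the index math in the reworked operator().
// num_tile_n, tile_size_m, tile_size_n and the sample ids are assumptions.
int main() {
  const int64_t num_tile_n = 4, tile_size_m = 1024, tile_size_n = 32;
  const int64_t tile_id = 9;  // work-group id along nd_range dimension 0
  const int64_t local_m = 3;  // local id along dimension 1, in [0, 8)
  const int64_t local_n = 17; // local id along dimension 2, in [0, 32)

  // Flattened tile id -> 2-D tile coordinates.
  int64_t tile_id_n = tile_id % num_tile_n;
  int64_t tile_id_m = tile_id / num_tile_n;

  // Tile coordinates -> element coordinates in the M x N gradient matrix.
  int64_t row_base = tile_id_m * tile_size_m; // first row handled by this tile
  int64_t actual_column = tile_id_n * tile_size_n + local_n;

  std::cout << "tile (" << tile_id_m << ", " << tile_id_n << "), rows start at "
            << row_base << ", column " << actual_column
            << ", row offset within a pass " << local_m << "\n";
  return 0;
}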
@@ -713,9 +725,11 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
   void sycl_ker_config_convention(sycl::handler& cgh) {
     local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>(
-        sycl::range<1>(N_ * tile_ / elements_per_thread_), cgh);
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
     local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>(
-        sycl::range<1>(N_ * tile_ / elements_per_thread_), cgh);
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
   }
 
   GammaBetaReduceFunctor(
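A quick size check on the new SLM allocation, assuming the default tile sizes chosen on the host side below (tile_size_m = 1024, tile_size_n = 32, elements_per_thread = 16): each accumulator array holds tile_size_n * tile_size_m / elements_per_thread = 32 * 1024 / 16 = 2048 accscalar_t entries, i.e. 8 KB when accscalar_t is float, so the beta and gamma arrays together occupy 16 KB of shared local memory, which matches the "slm is 16KB, 64*32 float * 2" comment added in the host-side hunk.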
@@ -725,7 +739,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       const scalar_t* X_data,
       weight_t* dg_block_data,
       weight_t* db_block_data,
-      int64_t tile,
+      int64_t num_tile_m,
+      int64_t num_tile_n,
+      int64_t tile_size_m,
+      int64_t tile_size_n,
       int64_t elements_per_thread,
       int64_t num_subgroup,
       int64_t M,
@@ -736,7 +753,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
         X_data_(X_data),
         dg_data_(dg_block_data),
         db_data_(db_block_data),
-        tile_(tile),
+        num_tile_m_(num_tile_m),
+        num_tile_n_(num_tile_n),
+        tile_size_m_(tile_size_m),
+        tile_size_n_(tile_size_n),
         elements_per_thread_(elements_per_thread),
         num_subgroup_(num_subgroup),
         M_(M),
@@ -751,7 +771,10 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   const scalar_t* X_data_;
   weight_t* dg_data_;
   weight_t* db_data_;
-  int64_t tile_;
+  int64_t num_tile_m_;
+  int64_t num_tile_n_;
+  int64_t tile_size_m_;
+  int64_t tile_size_n_;
   int64_t elements_per_thread_;
   int64_t num_subgroup_;
   int64_t M_;
@@ -1040,12 +1063,52 @@ void _layer_norm_backward_kernel(
           norm, config, can_use_32bit_index);
     }
   }
-  if (N <= 32 && M > 64 * 1024) {
-    const int num_subgroup = 8;
-    const int workgroup_size = N * num_subgroup;
-    const int tile_size = 1024;
+  auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
+  auto norm_config_global_size =
+      config_w.workgroup_num * config_w.block_row * config_w.workgroup_size;
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  // use two stage col reduction if norm config occupancy < 50%
+  // TODO: we can relax this restriction in future for better perf
+  bool use_two_stage_col_reduction =
+      (dY.dtype() == kFloat || dY.dtype() == kBFloat16 ||
+       dY.dtype() == kHalf) &&
+      norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
+  // cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
+  // in the M dimension
+  if (use_two_stage_col_reduction && M > 64 * 1024 &&
+      N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) {
+    const size_t local_size_x = 8;
+    const size_t SIMD = 32;
+    // workgroup size is 256
+    // slm is 16KB, 64*32 float * 2
+    // elements_per_thread is at least 16
     const int elements_per_thread = 16;
-    const int num_tile = (M + tile_size - 1) / tile_size;
+    int tile_size_m = 1024;
+    int tile_size_n = N < 32 ? N : 32;
+    int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+    int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+    bool adjust_m = true;
+    // for M = 64*1024, N = 1, we choose tile size (256, 16) on pvc
+    // TODO: we can tune these conditions in future
+    for (auto i = 0; i < 3; i++) {
+      // occupancy <= 50%
+      if (num_tile_m * num_tile_n * local_size_x * SIMD /
+              syclMaxSubGroupSize() * 2 <=
+          thread_slots) {
+        if (adjust_m) {
+          tile_size_m /= 2;
+          num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+          adjust_m = false;
+        } else {
+          tile_size_n /= 2;
+          num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+          adjust_m = true;
+        }
+      } else {
+        break;
+      }
+    }
+    // tile size can be (1024,32), (512,32), (512,16), (256, 16)
     // Change these parameters will cause changes in kernel
     const scalar_t* dY_data = dY.const_data_ptr<scalar_t>();
     const scalar_t* X_data = X.const_data_ptr<scalar_t>();
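For readers skimming the launch-side change, the occupancy-driven tile selection can be read as a small standalone function. The sketch below is for exposition only and is not code from the commit; pick_tiles, TileChoice, thread_slots, and max_subgroup_size are illustrative stand-ins for the syclGpuEuCount() * syclGpuHWThreadsPerEU() and syclMaxSubGroupSize() queries used above.

#include <cstdint>

// Sketch of the tile-size heuristic above; names and parameters are
// assumptions for illustration, not part of the commit.
struct TileChoice {
  int64_t tile_size_m, tile_size_n, num_tile_m, num_tile_n;
};

TileChoice pick_tiles(int64_t M, int64_t N, int64_t thread_slots,
                      int64_t max_subgroup_size) {
  const int64_t local_size_x = 8, SIMD = 32;
  TileChoice t{1024, N < 32 ? N : 32, 0, 0};
  t.num_tile_m = (M + t.tile_size_m - 1) / t.tile_size_m;
  t.num_tile_n = (N + t.tile_size_n - 1) / t.tile_size_n;
  bool adjust_m = true;
  for (int i = 0; i < 3; i++) {
    // Estimated sub-groups launched with the current tiling.
    int64_t subgroups =
        t.num_tile_m * t.num_tile_n * local_size_x * SIMD / max_subgroup_size;
    // Stop once estimated occupancy exceeds 50% of the device thread slots.
    if (subgroups * 2 > thread_slots)
      break;
    // Otherwise halve one tile dimension, alternating between M and N,
    // which doubles the number of tiles (work-groups) in that dimension.
    if (adjust_m) {
      t.tile_size_m /= 2;
      t.num_tile_m = (M + t.tile_size_m - 1) / t.tile_size_m;
    } else {
      t.tile_size_n /= 2;
      t.num_tile_n = (N + t.tile_size_n - 1) / t.tile_size_n;
    }
    adjust_m = !adjust_m;
  }
  return t;
}

Each halving doubles the number of tiles, and therefore work-groups, so the loop keeps shrinking tiles until the estimated sub-group count exceeds half of the device's thread slots or three halvings have been tried, which is how the commit arrives at the "(1024,32), (512,32), (512,16), (256, 16)" candidates noted in the comment.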
@@ -1059,32 +1122,37 @@ void _layer_norm_backward_kernel(
     if (dgamma.defined()) {
       auto options = dgamma.options();
       // TODO: how to set dgamma_blocks dtype = float32?
-      dgamma_blocks = at::empty({num_tile, N}, options);
+      dgamma_blocks = at::empty({num_tile_m, N}, options);
       dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>();
     }
     if (dbeta.defined()) {
       auto options = dbeta.options();
-      dbeta_blocks = at::empty({num_tile, N}, options);
+      dbeta_blocks = at::empty({num_tile_m, N}, options);
       dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
     }
 
-    int num_workgroup = (M + tile_size - 1) / tile_size;
+    size_t num_workgroup = num_tile_m * num_tile_n;
     GammaBetaReduceFunctor<scalar_t, accscalar_t, mean_t, weight_t> kfn(
         mean_data,
         var_data,
         dY_data,
         X_data,
         dgamma_blocks_ptr,
         dbeta_blocks_ptr,
-        tile_size,
+        num_tile_m,
+        num_tile_n,
+        tile_size_m,
+        tile_size_n,
         elements_per_thread,
-        num_subgroup,
+        local_size_x,
         M,
         N);
 
-    sycl_kernel_submit(
-        workgroup_size * num_workgroup,
-        workgroup_size,
+    sycl_kernel_submit<
+        GammaBetaReduceFunctor<scalar_t, accscalar_t, mean_t, weight_t>,
+        3>(
+        {num_workgroup, local_size_x, static_cast<size_t>(N < SIMD ? N : SIMD)},
+        {1, local_size_x, static_cast<size_t>(N < SIMD ? N : SIMD)},
         getCurrentSYCLQueue(),
         kfn);
     dgamma = dgamma_blocks.sum(0);
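The new submission launches one work-group per (M-tile, N-tile) pair, with an 8 x min(N, 32) work-group shape. The sketch below spells out that geometry as a plain sycl::nd_range<3>; it assumes sycl_kernel_submit forwards the two braced lists as the global and local ranges, which is how the call reads here, and the helper name is purely illustrative.

#include <sycl/sycl.hpp>

// Sketch only: the launch geometry implied by the submit call above.
sycl::nd_range<3> gamma_beta_launch_range(
    size_t num_tile_m, size_t num_tile_n, size_t N) {
  const size_t local_size_x = 8; // dim 1: rows of accumulating work-items
  const size_t SIMD = 32;        // dim 2 upper bound: columns per work-group
  size_t cols = N < SIMD ? N : SIMD;
  size_t num_workgroup = num_tile_m * num_tile_n; // one work-group per tile
  return sycl::nd_range<3>(
      sycl::range<3>(num_workgroup, local_size_x, cols), // global size
      sycl::range<3>(1, local_size_x, cols));            // work-group size
}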
