Skip to content

Commit 2001004

Browse files
committed
save
1 parent c436b11 commit 2001004

File tree

1 file changed

+166
-87
lines changed

1 file changed

+166
-87
lines changed

src/ATen/native/xpu/sycl/LayerNormKernels.cpp

Lines changed: 166 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,9 @@ template <
624624
typename scalar_t,
625625
typename accscalar_t,
626626
typename mean_t,
627-
typename weight_t>
627+
typename weight_t,
628+
bool have_gamma = true,
629+
bool have_beta = true>
628630
struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
629631
void operator()(sycl::nd_item<3> item) const {
630632
auto local_n = item.get_local_id(2); // [0, 32)
@@ -639,87 +641,70 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
639641
// slm_row 0, 8, 16...56
640642
for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
641643
slm_row += num_subgroup_) {
642-
accscalar_t sum_beta[8] = {accscalar_t(0)};
643-
accscalar_t sum_gamma[8] = {accscalar_t(0)};
644+
accscalar_t sum_beta = accscalar_t(0);
645+
accscalar_t sum_gamma = accscalar_t(0);
644646
// row 0, 128, 256, ...896
645647
auto row = tile_actual_row_base + slm_row * elements_per_thread_;
646648
for (int i = 0; i < elements_per_thread_; i++) {
647649
// row_local: row + 0, 8, 16, ...120
648650
auto row_local = row + i * num_subgroup_;
649651
auto actual_row = row_local + local_m;
650-
// TODO: tree reduction here for better acc
651-
if (actual_row < M_ && db_data_ != nullptr) {
652-
sum_beta[i / 2] += static_cast<accscalar_t>(
653-
dY_data_[actual_row * N_ + actual_column]);
652+
// TODO: try tree reduction here if accuracy loss
653+
if (actual_row < M_) {
654+
if constexpr (have_beta) {
655+
sum_beta += static_cast<accscalar_t>(
656+
dY_data_[actual_row * N_ + actual_column]);
657+
}
658+
if constexpr (have_gamma) {
659+
sum_gamma += static_cast<accscalar_t>(
660+
dY_data_[actual_row * N_ + actual_column]) *
661+
(static_cast<accscalar_t>(
662+
X_data_[actual_row * N_ + actual_column]) -
663+
static_cast<accscalar_t>(mean_data_[actual_row])) *
664+
static_cast<accscalar_t>(var_data_[actual_row]);
665+
}
654666
}
655-
if (actual_row < M_ && dg_data_ != nullptr) {
656-
sum_gamma[i / 2] += static_cast<accscalar_t>(
657-
dY_data_[actual_row * N_ + actual_column]) *
658-
(static_cast<accscalar_t>(
659-
X_data_[actual_row * N_ + actual_column]) -
660-
static_cast<accscalar_t>(mean_data_[actual_row])) *
661-
static_cast<accscalar_t>(var_data_[actual_row]);
662-
}
663-
}
664-
for (int i = 0; i < 4; i++) {
665-
sum_beta[i] += sum_beta[i + 4];
666-
sum_gamma[i] += sum_gamma[i + 4];
667667
}
668668

669669
local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
670-
(sum_beta[0] + sum_beta[1]) + (sum_beta[2] + sum_beta[3]);
670+
sum_beta;
671671
local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
672-
(sum_gamma[0] + sum_gamma[1]) + (sum_gamma[2] + sum_gamma[3]);
672+
sum_gamma;
673673
}
674674

675675
// item.barrier(sycl_local_fence);
676-
accscalar_t slm_sum_beta[4] = {accscalar_t(0)};
677-
accscalar_t slm_sum_gamma[4] = {accscalar_t(0)};
676+
accscalar_t slm_sum_beta = accscalar_t(0);
677+
accscalar_t slm_sum_gamma = accscalar_t(0);
678678
// slm row 64, 8 subgroup, i = 0,2,4,6
679679
// slm row 32, 8 subgroup, i = 0,2
680680
// slm row 16, 8 subgroup, i = 0
681681
for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
682-
i = i + 2) {
683-
slm_sum_beta[i / 2] =
684-
local_sum_beta_
685-
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
686-
local_sum_beta_
687-
[((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
688-
slm_sum_gamma[i / 2] =
689-
local_sum_gamma_
690-
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n] +
691-
local_sum_gamma_
692-
[((i + 1) * num_subgroup_ + local_m) * tile_size_n_ + local_n];
682+
i = i + 1) {
683+
slm_sum_beta += local_sum_beta_
684+
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
685+
slm_sum_gamma += local_sum_gamma_
686+
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
693687
}
694-
local_sum_beta_[local_m * tile_size_n_ + local_n] =
695-
(slm_sum_beta[0] + slm_sum_beta[1]) +
696-
(slm_sum_beta[2] + slm_sum_beta[3]);
697-
local_sum_gamma_[local_m * tile_size_n_ + local_n] =
698-
(slm_sum_gamma[0] + slm_sum_gamma[1]) +
699-
(slm_sum_gamma[2] + slm_sum_gamma[3]);
688+
local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
689+
local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
700690
}
701691
item.barrier(sycl_local_fence);
702-
accscalar_t output_sum_beta[4] = {accscalar_t(0)};
703-
accscalar_t output_sum_gamma[4] = {accscalar_t(0)};
692+
accscalar_t output_sum_beta = accscalar_t(0);
693+
accscalar_t output_sum_gamma = accscalar_t(0);
704694
if (local_m == 0 && actual_column < N_) {
705-
for (int i = 0; i < num_subgroup_; i = i + 2) {
706-
output_sum_beta[i / 2] = local_sum_beta_[i * tile_size_n_ + local_n] +
707-
local_sum_beta_[(i + 1) * tile_size_n_ + local_n];
708-
output_sum_gamma[i / 2] = local_sum_gamma_[i * tile_size_n_ + local_n] +
709-
local_sum_gamma_[(i + 1) * tile_size_n_ + local_n];
695+
for (int i = 0; i < num_subgroup_; i = i + 1) {
696+
output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
697+
output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
710698
}
711-
if (db_data_ != nullptr)
699+
if constexpr (have_beta) {
712700
db_data_[tile_id_m * N_ + actual_column] =
713-
(static_cast<weight_t>(output_sum_beta[0]) +
714-
static_cast<weight_t>(output_sum_beta[1])) +
715-
(static_cast<weight_t>(output_sum_beta[2]) +
716-
static_cast<weight_t>(output_sum_beta[3]));
717-
if (dg_data_ != nullptr)
701+
static_cast<weight_t>(output_sum_beta);
702+
}
703+
704+
if constexpr (have_gamma) {
718705
dg_data_[tile_id_m * N_ + actual_column] =
719-
(static_cast<weight_t>(output_sum_gamma[0]) +
720-
static_cast<weight_t>(output_sum_gamma[1])) +
721-
(static_cast<weight_t>(output_sum_gamma[2]) +
722-
static_cast<weight_t>(output_sum_gamma[3]));
706+
static_cast<weight_t>(output_sum_gamma);
707+
}
723708
}
724709
}
725710

@@ -739,7 +724,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
739724
const scalar_t* X_data,
740725
weight_t* dg_block_data,
741726
weight_t* db_block_data,
742-
int64_t num_tile_m,
743727
int64_t num_tile_n,
744728
int64_t tile_size_m,
745729
int64_t tile_size_n,
@@ -753,7 +737,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
753737
X_data_(X_data),
754738
dg_data_(dg_block_data),
755739
db_data_(db_block_data),
756-
num_tile_m_(num_tile_m),
757740
num_tile_n_(num_tile_n),
758741
tile_size_m_(tile_size_m),
759742
tile_size_n_(tile_size_n),
@@ -771,7 +754,6 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
771754
const scalar_t* X_data_;
772755
weight_t* dg_data_;
773756
weight_t* db_data_;
774-
int64_t num_tile_m_;
775757
int64_t num_tile_n_;
776758
int64_t tile_size_m_;
777759
int64_t tile_size_n_;
@@ -1132,35 +1114,132 @@ void _layer_norm_backward_kernel(
11321114
}
11331115

11341116
size_t num_workgroup = num_tile_m * num_tile_n;
1135-
GammaBetaReduceFunctor<scalar_t, accscalar_t, mean_t, weight_t> kfn(
1136-
mean_data,
1137-
var_data,
1138-
dY_data,
1139-
X_data,
1140-
dgamma_blocks_ptr,
1141-
dbeta_blocks_ptr,
1142-
num_tile_m,
1143-
num_tile_n,
1144-
tile_size_m,
1145-
tile_size_n,
1146-
elements_per_thread,
1147-
local_size_x,
1148-
M,
1149-
N);
1150-
1151-
sycl_kernel_submit<
1152-
GammaBetaReduceFunctor<scalar_t, accscalar_t, mean_t, weight_t>,
1153-
3>(
1154-
{num_workgroup, local_size_x, static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1155-
{1, local_size_x, static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1156-
getCurrentSYCLQueue(),
1157-
kfn);
1158-
dgamma = dgamma_blocks.sum(0);
1159-
dbeta = dbeta_blocks.sum(0);
1117+
if (dgamma.defined() && dbeta.defined()) {
1118+
GammaBetaReduceFunctor<
1119+
scalar_t,
1120+
accscalar_t,
1121+
mean_t,
1122+
weight_t,
1123+
true,
1124+
true>
1125+
kfn(mean_data,
1126+
var_data,
1127+
dY_data,
1128+
X_data,
1129+
dgamma_blocks_ptr,
1130+
dbeta_blocks_ptr,
1131+
num_tile_n,
1132+
tile_size_m,
1133+
tile_size_n,
1134+
elements_per_thread,
1135+
local_size_x,
1136+
M,
1137+
N);
1138+
1139+
sycl_kernel_submit<
1140+
GammaBetaReduceFunctor<
1141+
scalar_t,
1142+
accscalar_t,
1143+
mean_t,
1144+
weight_t,
1145+
true,
1146+
true>,
1147+
3>(
1148+
{num_workgroup,
1149+
local_size_x,
1150+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1151+
{1,
1152+
local_size_x,
1153+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1154+
getCurrentSYCLQueue(),
1155+
kfn);
1156+
dgamma = dgamma_blocks.sum(0);
1157+
dbeta = dbeta_blocks.sum(0);
1158+
} else if (dgamma.defined() && !dbeta.defined()) {
1159+
GammaBetaReduceFunctor<
1160+
scalar_t,
1161+
accscalar_t,
1162+
mean_t,
1163+
weight_t,
1164+
true,
1165+
false>
1166+
kfn(mean_data,
1167+
var_data,
1168+
dY_data,
1169+
X_data,
1170+
dgamma_blocks_ptr,
1171+
dbeta_blocks_ptr,
1172+
num_tile_n,
1173+
tile_size_m,
1174+
tile_size_n,
1175+
elements_per_thread,
1176+
local_size_x,
1177+
M,
1178+
N);
1179+
1180+
sycl_kernel_submit<
1181+
GammaBetaReduceFunctor<
1182+
scalar_t,
1183+
accscalar_t,
1184+
mean_t,
1185+
weight_t,
1186+
true,
1187+
false>,
1188+
3>(
1189+
{num_workgroup,
1190+
local_size_x,
1191+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1192+
{1,
1193+
local_size_x,
1194+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1195+
getCurrentSYCLQueue(),
1196+
kfn);
1197+
dgamma = dgamma_blocks.sum(0);
1198+
} else if (!dgamma.defined() && dbeta.defined()) {
1199+
GammaBetaReduceFunctor<
1200+
scalar_t,
1201+
accscalar_t,
1202+
mean_t,
1203+
weight_t,
1204+
false,
1205+
true>
1206+
kfn(mean_data,
1207+
var_data,
1208+
dY_data,
1209+
X_data,
1210+
dgamma_blocks_ptr,
1211+
dbeta_blocks_ptr,
1212+
num_tile_n,
1213+
tile_size_m,
1214+
tile_size_n,
1215+
elements_per_thread,
1216+
local_size_x,
1217+
M,
1218+
N);
1219+
1220+
sycl_kernel_submit<
1221+
GammaBetaReduceFunctor<
1222+
scalar_t,
1223+
accscalar_t,
1224+
mean_t,
1225+
weight_t,
1226+
false,
1227+
true>,
1228+
3>(
1229+
{num_workgroup,
1230+
local_size_x,
1231+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1232+
{1,
1233+
local_size_x,
1234+
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1235+
getCurrentSYCLQueue(),
1236+
kfn);
1237+
dbeta = dbeta_blocks.sum(0);
1238+
} else {
1239+
return;
1240+
}
11601241

11611242
} else {
1162-
auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
1163-
11641243
gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
11651244
dY, X, mean_data, var_data, dgamma, dbeta, config_w);
11661245
}

0 commit comments

Comments
 (0)