
Commit 7651ca2

jianyizh and Copilot authored
Layernorm bwd OPT (#1880)
I noticed that layer norm backward on gamma and beta is very slow when the column is much longer than the row, i.e. an [M, N] column reduction with M >> N. For example, in timm tnt_s_patch16_224 training, the layernorm bwd shape is [25088, 16, 24] with normalized shape [24], which launches only one workgroup. I use a two-stage column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC and 8.5 ms on BMG. After the optimization, we use GammaBetaReduceFunctor plus two sum kernels to finish the column reduction; they take 0.09 ms + 0.06 ms x 2 on PVC and 0.19 ms + 0.04 ms x 2 on BMG. --------- Co-authored-by: Copilot <[email protected]>
1 parent c091232 commit 7651ca2
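
To make the two-stage idea in the description concrete before the diff, here is a minimal CPU-side sketch of a two-stage column reduction for dbeta (summing an [M, N] matrix over M). It is illustrative only, not the SYCL kernel in this patch; the function name, float dtype, and row-major layout are assumptions. Stage 1 reduces independent row tiles into a [num_tile_m, N] block buffer (one workgroup per tile on the GPU); stage 2 collapses the blocks, which is what dgamma_blocks.sum(0) / dbeta_blocks.sum(0) do in the patch.

// Hypothetical CPU reference for the two-stage column reduction idea.
#include <algorithm>
#include <cstdint>
#include <vector>

// Stage 1: each tile independently reduces tile_size_m rows into one row of
// partial sums, so up to num_tile_m tiles can run in parallel on the GPU.
std::vector<float> reduce_columns_two_stage(
    const std::vector<float>& dY, // M x N, row-major
    int64_t M,
    int64_t N,
    int64_t tile_size_m) {
  const int64_t num_tile_m = (M + tile_size_m - 1) / tile_size_m;
  std::vector<float> blocks(num_tile_m * N, 0.f); // [num_tile_m, N]
  for (int64_t t = 0; t < num_tile_m; ++t) {
    const int64_t row_begin = t * tile_size_m;
    const int64_t row_end = std::min(M, row_begin + tile_size_m);
    for (int64_t r = row_begin; r < row_end; ++r) {
      for (int64_t c = 0; c < N; ++c) {
        blocks[t * N + c] += dY[r * N + c];
      }
    }
  }
  // Stage 2: a short reduction over the num_tile_m partial rows.
  std::vector<float> out(N, 0.f);
  for (int64_t t = 0; t < num_tile_m; ++t) {
    for (int64_t c = 0; c < N; ++c) {
      out[c] += blocks[t * N + c];
    }
  }
  return out;
}

For dgamma the same structure applies, with each dY element additionally multiplied by (x - mean) and the per-row var_data term before the tile sum, as in the kernel below.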

File tree

1 file changed: 355 additions, 5 deletions


src/ATen/native/xpu/sycl/LayerNormKernels.cpp

Lines changed: 355 additions & 5 deletions
@@ -599,8 +599,8 @@ void _layer_norm_kernel(
       beta.defined() ? can_vectorize(beta_data, alignment) : true;
 
   if ((std::is_same_v<T, float> || std::is_same_v<T, at::Half> ||
-       std::is_same_v<T, at::BFloat16>)&&N <=
-      static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+       std::is_same_v<T, at::BFloat16>) &&
+      N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
       N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
       can_vec_beta) {
     launch_vectorized_layer_norm_kernel(
@@ -620,6 +620,157 @@ void _layer_norm_kernel(
   }
 }
 
+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename mean_t,
+    typename weight_t,
+    bool have_gamma = true,
+    bool have_beta = true>
+struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  void operator()(sycl::nd_item<3> item) const {
+    auto local_n = item.get_local_id(2); // [0, 32)
+    auto local_m = item.get_local_id(1); // [0, 8)
+    for (auto tile_id = item.get_global_id(0);
+         tile_id < num_tile_n_ * num_tile_m_;
+         tile_id += item.get_group_range(0)) {
+      auto tile_id_n = tile_id % num_tile_n_;
+      auto tile_id_m = tile_id / num_tile_n_;
+      auto tile_actual_row_base = tile_id_m * tile_size_m_;
+      auto tile_actual_col_base = tile_id_n * tile_size_n_;
+      auto actual_column = tile_actual_col_base + local_n;
+      if (actual_column < N_) {
+        // slm_row 0, 8, 16...56
+        for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+             slm_row += num_subgroup_) {
+          accscalar_t sum_beta = accscalar_t(0);
+          accscalar_t sum_gamma = accscalar_t(0);
+          // row 0, 128, 256, ...896
+          auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+          for (int i = 0; i < elements_per_thread_; i++) {
+            // row_local: row + 0, 8, 16, ...120
+            auto row_local = row + i * num_subgroup_;
+            auto actual_row = row_local + local_m;
+            // TODO: try tree reduction here if accuracy loss
+            if (actual_row < M_) {
+              if constexpr (have_beta) {
+                sum_beta += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]);
+              }
+              if constexpr (have_gamma) {
+                sum_gamma += static_cast<accscalar_t>(
+                                 dY_data_[actual_row * N_ + actual_column]) *
+                    (static_cast<accscalar_t>(
+                         X_data_[actual_row * N_ + actual_column]) -
+                     static_cast<accscalar_t>(mean_data_[actual_row])) *
+                    static_cast<accscalar_t>(var_data_[actual_row]);
+              }
+            }
+          }
+
+          local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_beta;
+          local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_gamma;
+        }
+
+        // item.barrier(sycl_local_fence);
+        accscalar_t slm_sum_beta = accscalar_t(0);
+        accscalar_t slm_sum_gamma = accscalar_t(0);
+        // slm row 64, 8 subgroup, i = 0,2,4,6
+        // slm row 32, 8 subgroup, i = 0,2
+        // slm row 16, 8 subgroup, i = 0
+        for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+             i = i + 1) {
+          slm_sum_beta += local_sum_beta_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+          slm_sum_gamma += local_sum_gamma_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        }
+        local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
+        local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
+      }
+      item.barrier(sycl_local_fence);
+      accscalar_t output_sum_beta = accscalar_t(0);
+      accscalar_t output_sum_gamma = accscalar_t(0);
+      if (local_m == 0 && actual_column < N_) {
+        for (int i = 0; i < num_subgroup_; i = i + 1) {
+          output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
+          output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
+        }
+        if constexpr (have_beta) {
+          db_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_beta);
+        }
+
+        if constexpr (have_gamma) {
+          dg_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_gamma);
+        }
+      }
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+    local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+  }
+
+  GammaBetaReduceFunctor(
+      const mean_t* mean_data,
+      const mean_t* var_data,
+      const scalar_t* dY_data,
+      const scalar_t* X_data,
+      weight_t* dg_block_data,
+      weight_t* db_block_data,
+      int64_t num_tile_m,
+      int64_t num_tile_n,
+      int64_t tile_size_m,
+      int64_t tile_size_n,
+      int64_t elements_per_thread,
+      int64_t num_subgroup,
+      int64_t M,
+      int64_t N)
+      : mean_data_(mean_data),
+        var_data_(var_data),
+        dY_data_(dY_data),
+        X_data_(X_data),
+        dg_data_(dg_block_data),
+        db_data_(db_block_data),
+        num_tile_m_(num_tile_m),
+        num_tile_n_(num_tile_n),
+        tile_size_m_(tile_size_m),
+        tile_size_n_(tile_size_n),
+        elements_per_thread_(elements_per_thread),
+        num_subgroup_(num_subgroup),
+        M_(M),
+        N_(N),
+        local_sum_beta_(),
+        local_sum_gamma_() {}
+
+ private:
+  const mean_t* mean_data_;
+  const mean_t* var_data_;
+  const scalar_t* dY_data_;
+  const scalar_t* X_data_;
+  weight_t* dg_data_;
+  weight_t* db_data_;
+  int64_t num_tile_m_;
+  int64_t num_tile_n_;
+  int64_t tile_size_m_;
+  int64_t tile_size_n_;
+  int64_t elements_per_thread_;
+  int64_t num_subgroup_;
+  int64_t M_;
+  int64_t N_;
+  sycl_local_acc_t<accscalar_t, 1> local_sum_beta_;
+  sycl_local_acc_t<accscalar_t, 1> local_sum_gamma_;
+};
+
 template <
     typename scalar_t,
     typename accscalar_t,
@@ -900,10 +1051,209 @@ void _layer_norm_backward_kernel(
         norm, config, can_use_32bit_index);
     }
   }
-
   auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
-  gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
-      dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  auto norm_config_global_size =
+      config_w.workgroup_num * config_w.block_row * config_w.workgroup_size;
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  // use two stage col reduction if norm config occupancy < 50%
+  // TODO: we can relax this restriction in future for better perf
+  bool use_two_stage_col_reduction =
+      (dY.dtype() == kFloat || dY.dtype() == kBFloat16 ||
+       dY.dtype() == kHalf) &&
+      norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
+  // cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
+  // in the M dimension
+  if (use_two_stage_col_reduction && M > 64 * 1024 &&
+      N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) {
+    const size_t local_size_x = 8;
+    const size_t SIMD = 32;
+    // workgroup size is 256
+    // slm is 16KB, 64*32 float * 2
+    // elements_per_thread is at least 16
+    const int elements_per_thread = 16;
+    int tile_size_m = 1024;
+    int tile_size_n = N < 32 ? N : 32;
+    int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+    int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+    bool adjust_m = true;
+    // for M = 64*1024, N = 1, we choose tile size (256, 16) on pvc
+    // TODO: Consider tuning the tile size selection logic (tile_size_m, tile_size_n) and occupancy calculation
+    for (auto i = 0; i < 3; i++) {
+      // occupancy <= 50%
+      if (num_tile_m * num_tile_n * local_size_x * SIMD /
+              syclMaxSubGroupSize() * 2 <=
+          thread_slots) {
+        if (adjust_m) {
+          tile_size_m /= 2;
+          num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+          adjust_m = false;
+        } else {
+          tile_size_n /= 2;
+          num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+          adjust_m = true;
+        }
+      } else {
+        break;
+      }
+    }
+    // tile size can be (1024,32), (512,32), (512,16), (256, 16)
+    // Modifying these parameters (num_subgroup, workgroup_size, tile_size, elements_per_thread)
+    // will alter the kernel configuration, potentially affecting performance and behavior.
+    const scalar_t* dY_data = dY.const_data_ptr<scalar_t>();
+    const scalar_t* X_data = X.const_data_ptr<scalar_t>();
+    weight_t* dg_data =
+        dgamma.defined() ? dgamma.data_ptr<weight_t>() : nullptr;
+    weight_t* db_data = dbeta.defined() ? dbeta.data_ptr<weight_t>() : nullptr;
+    Tensor dgamma_blocks;
+    Tensor dbeta_blocks;
+    weight_t* dgamma_blocks_ptr = nullptr;
+    weight_t* dbeta_blocks_ptr = nullptr;
+    if (dgamma.defined()) {
+      auto options = dgamma.options();
+      // TODO: how to set dgamma_blocks dtype = float32?
+      dgamma_blocks = at::empty({num_tile_m, N}, options);
+      dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>();
+    }
+    if (dbeta.defined()) {
+      auto options = dbeta.options();
+      dbeta_blocks = at::empty({num_tile_m, N}, options);
+      dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
+    }
+
+    size_t num_workgroup =
+        std::min(num_tile_m * num_tile_n, static_cast<int>(thread_slots / local_size_x));
+    if (dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+      dbeta = dbeta_blocks.sum(0);
+    } else if (dgamma.defined() && !dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          false>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              false>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+    } else if (!dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          false,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              false,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dbeta = dbeta_blocks.sum(0);
+    } else {
+      return;
+    }
+
+  } else {
+    gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
+        dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  }
 }
 
 void layer_norm_kernel(
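
As a rough sanity check of the parallelism gained, here is a small standalone calculation for the tnt_s_patch16_224 shape quoted in the commit message, where dY is viewed as [M, N] with M = 25088 x 16 and N = 24. The values only mirror the initial tile-size choices in the diff; the occupancy-driven halving loop may shrink them further on a given device.

// Tile bookkeeping for the timm case from the commit message (assumed
// initial tile sizes from the diff; not accounting for the occupancy loop).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 25088LL * 16; // 401408 rows to reduce
  const int64_t N = 24;           // normalized shape
  const int64_t tile_size_m = 1024;
  const int64_t tile_size_n = N < 32 ? N : 32;                    // 24
  const int64_t num_tile_m = (M + tile_size_m - 1) / tile_size_m; // 392
  const int64_t num_tile_n = (N + tile_size_n - 1) / tile_size_n; // 1
  // The old path launched a single workgroup for this shape; the new path
  // has num_tile_m * num_tile_n independent tiles, each writing one row of
  // dgamma_blocks / dbeta_blocks that a final sum(0) collapses.
  std::printf("tiles: %lld x %lld = %lld\n",
              (long long)num_tile_m,
              (long long)num_tile_n,
              (long long)(num_tile_m * num_tile_n));
  return 0;
}

So the two-stage path has roughly 392 independent tiles to schedule where the old GammaBetaBackwardSimpleKernelFunctor launched one workgroup, at the cost of one extra [num_tile_m, N] buffer per requested gradient plus the final sum(0) kernels.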
