Layernorm bwd OPT #1880

Open · wants to merge 8 commits into main

Changes from 5 commits
359 changes: 354 additions & 5 deletions src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -599,8 +599,8 @@ void _layer_norm_kernel(
beta.defined() ? can_vectorize(beta_data, alignment) : true;

if ((std::is_same_v<T, float> || std::is_same_v<T, at::Half> ||
-    std::is_same_v<T, at::BFloat16>)&&N <=
-    static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+    std::is_same_v<T, at::BFloat16>) &&
+    N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
can_vec_beta) {
Copilot AI · Aug 6, 2025

[nitpick] The condition formatting is inconsistent. The && operator should be aligned with the opening parenthesis or consistently indented.
launch_vectorized_layer_norm_kernel(
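
For context on the bound in the condition above: std::numeric_limits<float>::digits is 24 (the width of the float significand), so 1ULL << digits is 16777216, and every integer count up to that value is exactly representable in a float. A tiny standalone check (illustrative only, not part of the PR):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // float has a 24-bit significand, so the vectorized path above is limited to
  // N <= 2^24 = 16777216, the largest count whose integers all fit exactly in float.
  constexpr int64_t max_n =
      static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits);
  std::printf("vectorized layer_norm path allows N <= %lld\n",
              static_cast<long long>(max_n));
  return 0;
}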
@@ -620,6 +620,157 @@ void _layer_norm_kernel(
}
}

template <
typename scalar_t,
typename accscalar_t,
typename mean_t,
typename weight_t,
bool have_gamma = true,
bool have_beta = true>
struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<3> item) const {
auto local_n = item.get_local_id(2); // [0, 32)
auto local_m = item.get_local_id(1); // [0, 8)
for (auto tile_id = item.get_global_id(0);
tile_id < num_tile_n_ * num_tile_m_;
tile_id += item.get_group_range(0)) {
auto tile_id_n = tile_id % num_tile_n_;
auto tile_id_m = tile_id / num_tile_n_;
auto tile_actual_row_base = tile_id_m * tile_size_m_;
auto tile_actual_col_base = tile_id_n * tile_size_n_;
auto actual_column = tile_actual_col_base + local_n;
if (actual_column < N_) {
// slm_row 0, 8, 16...56
for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
slm_row += num_subgroup_) {
accscalar_t sum_beta = accscalar_t(0);
accscalar_t sum_gamma = accscalar_t(0);
// row 0, 128, 256, ...896
auto row = tile_actual_row_base + slm_row * elements_per_thread_;
for (int i = 0; i < elements_per_thread_; i++) {
// row_local: row + 0, 8, 16, ...120
auto row_local = row + i * num_subgroup_;
auto actual_row = row_local + local_m;
// TODO: try a tree reduction here if there is accuracy loss
if (actual_row < M_) {
if constexpr (have_beta) {
sum_beta += static_cast<accscalar_t>(
dY_data_[actual_row * N_ + actual_column]);
}
if constexpr (have_gamma) {
sum_gamma += static_cast<accscalar_t>(
dY_data_[actual_row * N_ + actual_column]) *
(static_cast<accscalar_t>(
X_data_[actual_row * N_ + actual_column]) -
static_cast<accscalar_t>(mean_data_[actual_row])) *
static_cast<accscalar_t>(var_data_[actual_row]);
}
}
}

local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
sum_beta;
local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
sum_gamma;
}

// item.barrier(sycl_local_fence);
accscalar_t slm_sum_beta = accscalar_t(0);
accscalar_t slm_sum_gamma = accscalar_t(0);
// slm row 64, 8 subgroup, i = 0,2,4,6
// slm row 32, 8 subgroup, i = 0,2
// slm row 16, 8 subgroup, i = 0
for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
i = i + 1) {
slm_sum_beta += local_sum_beta_
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
slm_sum_gamma += local_sum_gamma_
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
}
local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
}
item.barrier(sycl_local_fence);
accscalar_t output_sum_beta = accscalar_t(0);
accscalar_t output_sum_gamma = accscalar_t(0);
if (local_m == 0 && actual_column < N_) {
for (int i = 0; i < num_subgroup_; i = i + 1) {
output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
}
if constexpr (have_beta) {
db_data_[tile_id_m * N_ + actual_column] =
static_cast<weight_t>(output_sum_beta);
}

if constexpr (have_gamma) {
dg_data_[tile_id_m * N_ + actual_column] =
static_cast<weight_t>(output_sum_gamma);
}
}
}
}

void sycl_ker_config_convention(sycl::handler& cgh) {
local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>(
sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
cgh);
local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>(
sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
cgh);
}

GammaBetaReduceFunctor(
const mean_t* mean_data,
const mean_t* var_data,
const scalar_t* dY_data,
const scalar_t* X_data,
weight_t* dg_block_data,
weight_t* db_block_data,
int64_t num_tile_m,
int64_t num_tile_n,
int64_t tile_size_m,
int64_t tile_size_n,
int64_t elements_per_thread,
int64_t num_subgroup,
int64_t M,
int64_t N)
: mean_data_(mean_data),
var_data_(var_data),
dY_data_(dY_data),
X_data_(X_data),
dg_data_(dg_block_data),
db_data_(db_block_data),
num_tile_m_(num_tile_m),
num_tile_n_(num_tile_n),
tile_size_m_(tile_size_m),
tile_size_n_(tile_size_n),
elements_per_thread_(elements_per_thread),
num_subgroup_(num_subgroup),
M_(M),
N_(N),
local_sum_beta_(),
local_sum_gamma_() {}

private:
const mean_t* mean_data_;
const mean_t* var_data_;
const scalar_t* dY_data_;
const scalar_t* X_data_;
weight_t* dg_data_;
weight_t* db_data_;
int64_t num_tile_m_;
int64_t num_tile_n_;
int64_t tile_size_m_;
int64_t tile_size_n_;
int64_t elements_per_thread_;
int64_t num_subgroup_;
int64_t M_;
int64_t N_;
sycl_local_acc_t<accscalar_t, 1> local_sum_beta_;
sycl_local_acc_t<accscalar_t, 1> local_sum_gamma_;
};
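
To make the tiling above concrete: with the (tile_size_m, tile_size_n) = (1024, 32), elements_per_thread = 16, num_subgroup = 8 configuration chosen in the launch code below, each work-group reduces a 1024 x 32 tile of dY, and the two SLM buffers each hold tile_size_m / elements_per_thread * tile_size_n = 64 * 32 accumulators, i.e. 16 KB total for the float pair, matching the comment in the launch code. The following host-side sketch (illustrative only, not part of the PR) enumerates the rows one work-item accumulates before the SLM reduction:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t tile_size_m = 1024;        // rows per tile
  const int64_t elements_per_thread = 16;  // rows accumulated per pass
  const int64_t num_subgroup = 8;          // local_size_x in the launch code
  const int64_t local_m = 3, local_n = 5;  // one work-item inside the group
  std::vector<int64_t> rows;
  // slm_row walks 0, 8, ..., 56: one accumulation pass per SLM row slot.
  for (int64_t slm_row = 0; slm_row < tile_size_m / elements_per_thread;
       slm_row += num_subgroup) {
    const int64_t row_base = slm_row * elements_per_thread;  // 0, 128, ..., 896
    for (int64_t i = 0; i < elements_per_thread; ++i) {
      rows.push_back(row_base + i * num_subgroup + local_m);  // stride of 8 rows
    }
  }
  // 8 passes x 16 rows = 128 rows of column local_n for this work-item;
  // the 8 sub-groups together cover all 1024 rows of the tile.
  std::printf("work-item (%lld, %lld): %zu rows, first %lld, last %lld\n",
              (long long)local_m, (long long)local_n, rows.size(),
              (long long)rows.front(), (long long)rows.back());
  return 0;
}

Each pass writes its partial sums to the SLM slot at (slm_row + local_m, local_n), and the work-items with local_m == 0 then sum across sub-groups to produce the per-tile outputs.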

template <
typename scalar_t,
typename accscalar_t,
@@ -900,10 +1051,208 @@ void _layer_norm_backward_kernel(
norm, config, can_use_32bit_index);
}
}

auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
dY, X, mean_data, var_data, dgamma, dbeta, config_w);
auto norm_config_global_size =
config_w.workgroup_num * config_w.block_row * config_w.workgroup_size;
int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
// use two stage col reduction if norm config occupancy < 50%
// TODO: we can relax this restriction in the future for better perf
bool use_two_stage_col_reduction =
(dY.dtype() == kFloat || dY.dtype() == kBFloat16 ||
dY.dtype() == kHalf) &&
norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
// CUDA uses the condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
// in the M dimension
if (use_two_stage_col_reduction && M > 64 * 1024 &&
N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) {
const size_t local_size_x = 8;
const size_t SIMD = 32;
// workgroup size is 256
// slm is 16KB, 64*32 float * 2
// elements_per_thread is at least 16
const int elements_per_thread = 16;
int tile_size_m = 1024;
int tile_size_n = N < 32 ? N : 32;
int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
bool adjust_m = true;
// for M = 64*1024, N = 1, we choose tile size (256, 16) on pvc
// TODO: we can tune these conditions in future
for (auto i = 0; i < 3; i++) {
// occupancy <= 50%
if (num_tile_m * num_tile_n * local_size_x * SIMD /
syclMaxSubGroupSize() * 2 <=
thread_slots) {
if (adjust_m) {
tile_size_m /= 2;
num_tile_m = (M + tile_size_m - 1) / tile_size_m;
adjust_m = false;
} else {
tile_size_n /= 2;
num_tile_n = (N + tile_size_n - 1) / tile_size_n;
adjust_m = true;
}
} else {
break;
}
}
// tile size can be (1024,32), (512,32), (512,16), (256, 16)
// Changing these parameters will cause changes in the kernel launch
const scalar_t* dY_data = dY.const_data_ptr<scalar_t>();
const scalar_t* X_data = X.const_data_ptr<scalar_t>();
weight_t* dg_data =
dgamma.defined() ? dgamma.data_ptr<weight_t>() : nullptr;
weight_t* db_data = dbeta.defined() ? dbeta.data_ptr<weight_t>() : nullptr;
Tensor dgamma_blocks;
Tensor dbeta_blocks;
weight_t* dgamma_blocks_ptr = nullptr;
weight_t* dbeta_blocks_ptr = nullptr;
if (dgamma.defined()) {
auto options = dgamma.options();
// TODO: how to set dgamma_blocks dtype = float32?
dgamma_blocks = at::empty({num_tile_m, N}, options);
dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>();
}
if (dbeta.defined()) {
auto options = dbeta.options();
dbeta_blocks = at::empty({num_tile_m, N}, options);
dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
Copilot AI · Aug 6, 2025

This TODO comment suggests uncertainty about the data type handling. The comment should either be resolved or provide more context about why float32 might be needed and what the current behavior is.

Suggested change
dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
// Set dgamma_blocks dtype to float32 for numerical stability in reduction
dgamma_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
dgamma_blocks_ptr = dgamma_blocks.data_ptr<float>();
}
if (dbeta.defined()) {
auto options = dbeta.options();
dbeta_blocks = at::empty({num_tile_m, N}, options.dtype(at::kFloat));
dbeta_blocks_ptr = dbeta_blocks.data_ptr<float>();

}

size_t num_workgroup =
std::min(num_tile_m * num_tile_n, static_cast<int>(thread_slots / local_size_x));
if (dgamma.defined() && dbeta.defined()) {
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
true,
true>
kfn(mean_data,
var_data,
dY_data,
X_data,
dgamma_blocks_ptr,
dbeta_blocks_ptr,
num_tile_m,
num_tile_n,
tile_size_m,
tile_size_n,
elements_per_thread,
local_size_x,
M,
N);

sycl_kernel_submit<
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
true,
true>,
3>(
{num_workgroup,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
{1,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
getCurrentSYCLQueue(),
kfn);
dgamma = dgamma_blocks.sum(0);
dbeta = dbeta_blocks.sum(0);
} else if (dgamma.defined() && !dbeta.defined()) {
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
true,
false>
kfn(mean_data,
var_data,
dY_data,
X_data,
dgamma_blocks_ptr,
dbeta_blocks_ptr,
num_tile_m,
num_tile_n,
tile_size_m,
tile_size_n,
elements_per_thread,
local_size_x,
M,
N);

sycl_kernel_submit<
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
true,
false>,
3>(
{num_workgroup,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
{1,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
getCurrentSYCLQueue(),
kfn);
dgamma = dgamma_blocks.sum(0);
} else if (!dgamma.defined() && dbeta.defined()) {
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
false,
true>
kfn(mean_data,
var_data,
dY_data,
X_data,
dgamma_blocks_ptr,
dbeta_blocks_ptr,
num_tile_m,
num_tile_n,
tile_size_m,
tile_size_n,
elements_per_thread,
local_size_x,
M,
N);

sycl_kernel_submit<
GammaBetaReduceFunctor<
scalar_t,
accscalar_t,
mean_t,
weight_t,
false,
true>,
3>(
{num_workgroup,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
{1,
local_size_x,
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
getCurrentSYCLQueue(),
kfn);
dbeta = dbeta_blocks.sum(0);
} else {
return;
}

} else {
gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>(
dY, X, mean_data, var_data, dgamma, dbeta, config_w);
}
}
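
For reference, the tile-size selection in the two-stage path above alternately halves tile_size_m and tile_size_n (landing on one of (1024, 32), (512, 32), (512, 16), (256, 16), as the comment notes) as long as the launch would still occupy at most roughly 50% of the device's thread slots; the tile kernel then writes per-tile partial sums of shape {num_tile_m, N}, which sum(0) reduces to the final {N}-sized gradients. A standalone sketch of the selection heuristic (not the PR code), with thread_slots and sub_group_size passed in as stand-ins for syclGpuEuCount() * syclGpuHWThreadsPerEU() and syclMaxSubGroupSize():

#include <algorithm>
#include <cstdint>
#include <utility>

// Sketch only: shrink the tile until the grid of num_tile_m * num_tile_n
// work-groups would use more than ~50% of the device's thread slots,
// mirroring the loop in _layer_norm_backward_kernel.
std::pair<int64_t, int64_t> pick_tile_size(
    int64_t M, int64_t N, int64_t thread_slots, int64_t sub_group_size) {
  const int64_t local_size_x = 8;  // sub-groups per work-group
  const int64_t SIMD = 32;         // work-items per sub-group
  int64_t tile_m = 1024;
  int64_t tile_n = std::min<int64_t>(N, 32);
  bool adjust_m = true;
  for (int iter = 0; iter < 3; ++iter) {
    const int64_t num_m = (M + tile_m - 1) / tile_m;
    const int64_t num_n = (N + tile_n - 1) / tile_n;
    // Stop once occupancy would exceed ~50% of the available thread slots.
    if (num_m * num_n * local_size_x * SIMD / sub_group_size * 2 > thread_slots) {
      break;
    }
    if (adjust_m) {
      tile_m /= 2;
    } else {
      tile_n /= 2;
    }
    adjust_m = !adjust_m;
  }
  return {tile_m, tile_n};
}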

void layer_norm_kernel(