-
Notifications
You must be signed in to change notification settings - Fork 49
Layernorm bwd OPT #1880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Layernorm bwd OPT #1880
Changes from 5 commits
e21ccec
e0f15de
02b55ff
7a06f6c
d0149b8
8f1fcd7
d94ca74
e8bef72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -599,8 +599,8 @@ void _layer_norm_kernel( | |||||||||||||||||||
beta.defined() ? can_vectorize(beta_data, alignment) : true; | ||||||||||||||||||||
|
||||||||||||||||||||
if ((std::is_same_v<T, float> || std::is_same_v<T, at::Half> || | ||||||||||||||||||||
std::is_same_v<T, at::BFloat16>)&&N <= | ||||||||||||||||||||
static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && | ||||||||||||||||||||
std::is_same_v<T, at::BFloat16>) && | ||||||||||||||||||||
N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && | ||||||||||||||||||||
N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && | ||||||||||||||||||||
can_vec_beta) { | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] The condition formatting is inconsistent. The Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||
launch_vectorized_layer_norm_kernel( | ||||||||||||||||||||
|
@@ -620,6 +620,157 @@ void _layer_norm_kernel( | |||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
template < | ||||||||||||||||||||
typename scalar_t, | ||||||||||||||||||||
typename accscalar_t, | ||||||||||||||||||||
typename mean_t, | ||||||||||||||||||||
typename weight_t, | ||||||||||||||||||||
bool have_gamma = true, | ||||||||||||||||||||
bool have_beta = true> | ||||||||||||||||||||
struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { | ||||||||||||||||||||
void operator()(sycl::nd_item<3> item) const { | ||||||||||||||||||||
auto local_n = item.get_local_id(2); // [0, 32) | ||||||||||||||||||||
auto local_m = item.get_local_id(1); // [0, 8) | ||||||||||||||||||||
for (auto tile_id = item.get_global_id(0); | ||||||||||||||||||||
tile_id < num_tile_n_ * num_tile_m_; | ||||||||||||||||||||
tile_id += item.get_group_range(0)) { | ||||||||||||||||||||
auto tile_id_n = tile_id % num_tile_n_; | ||||||||||||||||||||
auto tile_id_m = tile_id / num_tile_n_; | ||||||||||||||||||||
auto tile_actual_row_base = tile_id_m * tile_size_m_; | ||||||||||||||||||||
auto tile_actual_col_base = tile_id_n * tile_size_n_; | ||||||||||||||||||||
auto actual_column = tile_actual_col_base + local_n; | ||||||||||||||||||||
if (actual_column < N_) { | ||||||||||||||||||||
// slm_row 0, 8, 16...56 | ||||||||||||||||||||
for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_; | ||||||||||||||||||||
slm_row += num_subgroup_) { | ||||||||||||||||||||
accscalar_t sum_beta = accscalar_t(0); | ||||||||||||||||||||
accscalar_t sum_gamma = accscalar_t(0); | ||||||||||||||||||||
// row 0, 128, 256, ...896 | ||||||||||||||||||||
auto row = tile_actual_row_base + slm_row * elements_per_thread_; | ||||||||||||||||||||
for (int i = 0; i < elements_per_thread_; i++) { | ||||||||||||||||||||
// row_local: row + 0, 8, 16, ...120 | ||||||||||||||||||||
auto row_local = row + i * num_subgroup_; | ||||||||||||||||||||
auto actual_row = row_local + local_m; | ||||||||||||||||||||
// TODO: try tree reduction here if accuracy loss | ||||||||||||||||||||
if (actual_row < M_) { | ||||||||||||||||||||
if constexpr (have_beta) { | ||||||||||||||||||||
sum_beta += static_cast<accscalar_t>( | ||||||||||||||||||||
dY_data_[actual_row * N_ + actual_column]); | ||||||||||||||||||||
} | ||||||||||||||||||||
if constexpr (have_gamma) { | ||||||||||||||||||||
sum_gamma += static_cast<accscalar_t>( | ||||||||||||||||||||
dY_data_[actual_row * N_ + actual_column]) * | ||||||||||||||||||||
(static_cast<accscalar_t>( | ||||||||||||||||||||
X_data_[actual_row * N_ + actual_column]) - | ||||||||||||||||||||
static_cast<accscalar_t>(mean_data_[actual_row])) * | ||||||||||||||||||||
static_cast<accscalar_t>(var_data_[actual_row]); | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] = | ||||||||||||||||||||
sum_beta; | ||||||||||||||||||||
local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] = | ||||||||||||||||||||
sum_gamma; | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
// item.barrier(sycl_local_fence); | ||||||||||||||||||||
accscalar_t slm_sum_beta = accscalar_t(0); | ||||||||||||||||||||
accscalar_t slm_sum_gamma = accscalar_t(0); | ||||||||||||||||||||
// slm row 64, 8 subgroup, i = 0,2,4,6 | ||||||||||||||||||||
// slm row 32, 8 subgroup, i = 0,2 | ||||||||||||||||||||
// slm row 16, 8 subgroup, i = 0 | ||||||||||||||||||||
for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_; | ||||||||||||||||||||
i = i + 1) { | ||||||||||||||||||||
slm_sum_beta += local_sum_beta_ | ||||||||||||||||||||
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n]; | ||||||||||||||||||||
slm_sum_gamma += local_sum_gamma_ | ||||||||||||||||||||
[(i * num_subgroup_ + local_m) * tile_size_n_ + local_n]; | ||||||||||||||||||||
} | ||||||||||||||||||||
local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta; | ||||||||||||||||||||
local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma; | ||||||||||||||||||||
} | ||||||||||||||||||||
item.barrier(sycl_local_fence); | ||||||||||||||||||||
accscalar_t output_sum_beta = accscalar_t(0); | ||||||||||||||||||||
accscalar_t output_sum_gamma = accscalar_t(0); | ||||||||||||||||||||
if (local_m == 0 && actual_column < N_) { | ||||||||||||||||||||
for (int i = 0; i < num_subgroup_; i = i + 1) { | ||||||||||||||||||||
output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n]; | ||||||||||||||||||||
output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n]; | ||||||||||||||||||||
} | ||||||||||||||||||||
if constexpr (have_beta) { | ||||||||||||||||||||
db_data_[tile_id_m * N_ + actual_column] = | ||||||||||||||||||||
static_cast<weight_t>(output_sum_beta); | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
if constexpr (have_gamma) { | ||||||||||||||||||||
dg_data_[tile_id_m * N_ + actual_column] = | ||||||||||||||||||||
static_cast<weight_t>(output_sum_gamma); | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
void sycl_ker_config_convention(sycl::handler& cgh) { | ||||||||||||||||||||
local_sum_beta_ = sycl_local_acc_t<accscalar_t, 1>( | ||||||||||||||||||||
sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_), | ||||||||||||||||||||
cgh); | ||||||||||||||||||||
local_sum_gamma_ = sycl_local_acc_t<accscalar_t, 1>( | ||||||||||||||||||||
sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_), | ||||||||||||||||||||
cgh); | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
GammaBetaReduceFunctor( | ||||||||||||||||||||
const mean_t* mean_data, | ||||||||||||||||||||
const mean_t* var_data, | ||||||||||||||||||||
const scalar_t* dY_data, | ||||||||||||||||||||
const scalar_t* X_data, | ||||||||||||||||||||
weight_t* dg_block_data, | ||||||||||||||||||||
weight_t* db_block_data, | ||||||||||||||||||||
int64_t num_tile_m, | ||||||||||||||||||||
int64_t num_tile_n, | ||||||||||||||||||||
int64_t tile_size_m, | ||||||||||||||||||||
int64_t tile_size_n, | ||||||||||||||||||||
int64_t elements_per_thread, | ||||||||||||||||||||
int64_t num_subgroup, | ||||||||||||||||||||
int64_t M, | ||||||||||||||||||||
int64_t N) | ||||||||||||||||||||
: mean_data_(mean_data), | ||||||||||||||||||||
var_data_(var_data), | ||||||||||||||||||||
dY_data_(dY_data), | ||||||||||||||||||||
X_data_(X_data), | ||||||||||||||||||||
dg_data_(dg_block_data), | ||||||||||||||||||||
db_data_(db_block_data), | ||||||||||||||||||||
num_tile_m_(num_tile_m), | ||||||||||||||||||||
num_tile_n_(num_tile_n), | ||||||||||||||||||||
tile_size_m_(tile_size_m), | ||||||||||||||||||||
tile_size_n_(tile_size_n), | ||||||||||||||||||||
elements_per_thread_(elements_per_thread), | ||||||||||||||||||||
num_subgroup_(num_subgroup), | ||||||||||||||||||||
M_(M), | ||||||||||||||||||||
N_(N), | ||||||||||||||||||||
local_sum_beta_(), | ||||||||||||||||||||
local_sum_gamma_() {} | ||||||||||||||||||||
|
||||||||||||||||||||
private: | ||||||||||||||||||||
const mean_t* mean_data_; | ||||||||||||||||||||
const mean_t* var_data_; | ||||||||||||||||||||
const scalar_t* dY_data_; | ||||||||||||||||||||
const scalar_t* X_data_; | ||||||||||||||||||||
weight_t* dg_data_; | ||||||||||||||||||||
weight_t* db_data_; | ||||||||||||||||||||
int64_t num_tile_m_; | ||||||||||||||||||||
int64_t num_tile_n_; | ||||||||||||||||||||
int64_t tile_size_m_; | ||||||||||||||||||||
int64_t tile_size_n_; | ||||||||||||||||||||
int64_t elements_per_thread_; | ||||||||||||||||||||
int64_t num_subgroup_; | ||||||||||||||||||||
int64_t M_; | ||||||||||||||||||||
int64_t N_; | ||||||||||||||||||||
sycl_local_acc_t<accscalar_t, 1> local_sum_beta_; | ||||||||||||||||||||
sycl_local_acc_t<accscalar_t, 1> local_sum_gamma_; | ||||||||||||||||||||
}; | ||||||||||||||||||||
|
||||||||||||||||||||
template < | ||||||||||||||||||||
typename scalar_t, | ||||||||||||||||||||
typename accscalar_t, | ||||||||||||||||||||
|
@@ -900,10 +1051,208 @@ void _layer_norm_backward_kernel( | |||||||||||||||||||
norm, config, can_use_32bit_index); | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
auto config_w = NormConfig(M, N, 0, sizeof(scalar_t)); | ||||||||||||||||||||
gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>( | ||||||||||||||||||||
dY, X, mean_data, var_data, dgamma, dbeta, config_w); | ||||||||||||||||||||
auto norm_config_global_size = | ||||||||||||||||||||
config_w.workgroup_num * config_w.block_row * config_w.workgroup_size; | ||||||||||||||||||||
int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); | ||||||||||||||||||||
// use two stage col reduction if norm config occupancy < 50% | ||||||||||||||||||||
// TODO: we can releax this restriction in future for better perf | ||||||||||||||||||||
jianyizh marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
bool use_two_stage_col_reduction = | ||||||||||||||||||||
(dY.dtype() == kFloat || dY.dtype() == kBFloat16 || | ||||||||||||||||||||
dY.dtype() == kHalf) && | ||||||||||||||||||||
norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots; | ||||||||||||||||||||
// cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize | ||||||||||||||||||||
// in the M dimension | ||||||||||||||||||||
if (use_two_stage_col_reduction && M > 64 * 1024 && | ||||||||||||||||||||
N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) { | ||||||||||||||||||||
const size_t local_size_x = 8; | ||||||||||||||||||||
const size_t SIMD = 32; | ||||||||||||||||||||
// workgroup size is 256 | ||||||||||||||||||||
// slm is 16KB, 64*32 float * 2 | ||||||||||||||||||||
// elements_per_thread is at least 16 | ||||||||||||||||||||
const int elements_per_thread = 16; | ||||||||||||||||||||
int tile_size_m = 1024; | ||||||||||||||||||||
int tile_size_n = N < 32 ? N : 32; | ||||||||||||||||||||
int num_tile_m = (M + tile_size_m - 1) / tile_size_m; | ||||||||||||||||||||
int num_tile_n = (N + tile_size_n - 1) / tile_size_n; | ||||||||||||||||||||
bool adjust_m = true; | ||||||||||||||||||||
// for M = 64*1024, N = 1, we choose tile size (256, 16) on pvc | ||||||||||||||||||||
// TODO: we can tune these conditions in future | ||||||||||||||||||||
jianyizh marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
for (auto i = 0; i < 3; i++) { | ||||||||||||||||||||
// occupancy <= 50% | ||||||||||||||||||||
if (num_tile_m * num_tile_n * local_size_x * SIMD / | ||||||||||||||||||||
syclMaxSubGroupSize() * 2 <= | ||||||||||||||||||||
thread_slots) { | ||||||||||||||||||||
if (adjust_m) { | ||||||||||||||||||||
tile_size_m /= 2; | ||||||||||||||||||||
num_tile_m = (M + tile_size_m - 1) / tile_size_m; | ||||||||||||||||||||
adjust_m = false; | ||||||||||||||||||||
} else { | ||||||||||||||||||||
tile_size_n /= 2; | ||||||||||||||||||||
num_tile_n = (N + tile_size_n - 1) / tile_size_n; | ||||||||||||||||||||
adjust_m = true; | ||||||||||||||||||||
} | ||||||||||||||||||||
} else { | ||||||||||||||||||||
break; | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
// tile size can be (1024,32), (512,32), (512,16), (256, 16) | ||||||||||||||||||||
// Change these parameters will cause changes in kernel | ||||||||||||||||||||
jianyizh marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
const scalar_t* dY_data = dY.const_data_ptr<scalar_t>(); | ||||||||||||||||||||
const scalar_t* X_data = X.const_data_ptr<scalar_t>(); | ||||||||||||||||||||
weight_t* dg_data = | ||||||||||||||||||||
dgamma.defined() ? dgamma.data_ptr<weight_t>() : nullptr; | ||||||||||||||||||||
weight_t* db_data = dbeta.defined() ? dbeta.data_ptr<weight_t>() : nullptr; | ||||||||||||||||||||
Tensor dgamma_blocks; | ||||||||||||||||||||
Tensor dbeta_blocks; | ||||||||||||||||||||
weight_t* dgamma_blocks_ptr = nullptr; | ||||||||||||||||||||
weight_t* dbeta_blocks_ptr = nullptr; | ||||||||||||||||||||
if (dgamma.defined()) { | ||||||||||||||||||||
auto options = dgamma.options(); | ||||||||||||||||||||
// TODO: how to set dgamma_blocks dtype = float32? | ||||||||||||||||||||
dgamma_blocks = at::empty({num_tile_m, N}, options); | ||||||||||||||||||||
dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>(); | ||||||||||||||||||||
} | ||||||||||||||||||||
if (dbeta.defined()) { | ||||||||||||||||||||
auto options = dbeta.options(); | ||||||||||||||||||||
dbeta_blocks = at::empty({num_tile_m, N}, options); | ||||||||||||||||||||
dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>(); | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This TODO comment suggests uncertainty about the data type handling. The comment should either be resolved or provide more context about why float32 might be needed and what the current behavior is.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
size_t num_workgroup = | ||||||||||||||||||||
std::min(num_tile_m * num_tile_n, static_cast<int>(thread_slots / local_size_x)); | ||||||||||||||||||||
if (dgamma.defined() && dbeta.defined()) { | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
true, | ||||||||||||||||||||
true> | ||||||||||||||||||||
kfn(mean_data, | ||||||||||||||||||||
var_data, | ||||||||||||||||||||
dY_data, | ||||||||||||||||||||
X_data, | ||||||||||||||||||||
dgamma_blocks_ptr, | ||||||||||||||||||||
dbeta_blocks_ptr, | ||||||||||||||||||||
num_tile_m, | ||||||||||||||||||||
num_tile_n, | ||||||||||||||||||||
tile_size_m, | ||||||||||||||||||||
tile_size_n, | ||||||||||||||||||||
elements_per_thread, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
M, | ||||||||||||||||||||
N); | ||||||||||||||||||||
|
||||||||||||||||||||
sycl_kernel_submit< | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
true, | ||||||||||||||||||||
true>, | ||||||||||||||||||||
3>( | ||||||||||||||||||||
{num_workgroup, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
{1, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
getCurrentSYCLQueue(), | ||||||||||||||||||||
kfn); | ||||||||||||||||||||
dgamma = dgamma_blocks.sum(0); | ||||||||||||||||||||
dbeta = dbeta_blocks.sum(0); | ||||||||||||||||||||
} else if (dgamma.defined() && !dbeta.defined()) { | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
true, | ||||||||||||||||||||
false> | ||||||||||||||||||||
kfn(mean_data, | ||||||||||||||||||||
var_data, | ||||||||||||||||||||
dY_data, | ||||||||||||||||||||
X_data, | ||||||||||||||||||||
dgamma_blocks_ptr, | ||||||||||||||||||||
dbeta_blocks_ptr, | ||||||||||||||||||||
num_tile_m, | ||||||||||||||||||||
num_tile_n, | ||||||||||||||||||||
tile_size_m, | ||||||||||||||||||||
tile_size_n, | ||||||||||||||||||||
elements_per_thread, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
M, | ||||||||||||||||||||
N); | ||||||||||||||||||||
|
||||||||||||||||||||
sycl_kernel_submit< | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
true, | ||||||||||||||||||||
false>, | ||||||||||||||||||||
3>( | ||||||||||||||||||||
{num_workgroup, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
{1, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
getCurrentSYCLQueue(), | ||||||||||||||||||||
kfn); | ||||||||||||||||||||
dgamma = dgamma_blocks.sum(0); | ||||||||||||||||||||
} else if (!dgamma.defined() && dbeta.defined()) { | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
false, | ||||||||||||||||||||
true> | ||||||||||||||||||||
kfn(mean_data, | ||||||||||||||||||||
var_data, | ||||||||||||||||||||
dY_data, | ||||||||||||||||||||
X_data, | ||||||||||||||||||||
dgamma_blocks_ptr, | ||||||||||||||||||||
dbeta_blocks_ptr, | ||||||||||||||||||||
num_tile_m, | ||||||||||||||||||||
num_tile_n, | ||||||||||||||||||||
tile_size_m, | ||||||||||||||||||||
tile_size_n, | ||||||||||||||||||||
elements_per_thread, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
M, | ||||||||||||||||||||
N); | ||||||||||||||||||||
|
||||||||||||||||||||
sycl_kernel_submit< | ||||||||||||||||||||
GammaBetaReduceFunctor< | ||||||||||||||||||||
scalar_t, | ||||||||||||||||||||
accscalar_t, | ||||||||||||||||||||
mean_t, | ||||||||||||||||||||
weight_t, | ||||||||||||||||||||
false, | ||||||||||||||||||||
true>, | ||||||||||||||||||||
3>( | ||||||||||||||||||||
{num_workgroup, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
{1, | ||||||||||||||||||||
local_size_x, | ||||||||||||||||||||
static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)}, | ||||||||||||||||||||
getCurrentSYCLQueue(), | ||||||||||||||||||||
kfn); | ||||||||||||||||||||
dbeta = dbeta_blocks.sum(0); | ||||||||||||||||||||
} else { | ||||||||||||||||||||
return; | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
} else { | ||||||||||||||||||||
gamma_beta_bwd_simple_kernel<scalar_t, accscalar_t, mean_t, weight_t>( | ||||||||||||||||||||
dY, X, mean_data, var_data, dgamma, dbeta, config_w); | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
void layer_norm_kernel( | ||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.