diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 2d62ad0584..4ff0ada5eb 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -599,8 +599,8 @@ void _layer_norm_kernel(
       beta.defined() ? can_vectorize(beta_data, alignment) : true;
   if ((std::is_same_v<scalar_t, float> || std::is_same_v<scalar_t, at::Half> ||
-       std::is_same_v<scalar_t, at::BFloat16>)&&N <=
-      static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+       std::is_same_v<scalar_t, at::BFloat16>) &&
+      N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
       N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
       can_vec_beta) {
     launch_vectorized_layer_norm_kernel(
@@ -620,6 +620,157 @@ void _layer_norm_kernel(
   }
 }
 
+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename mean_t,
+    typename weight_t,
+    bool have_gamma = true,
+    bool have_beta = true>
+struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  void operator()(sycl::nd_item<3> item) const {
+    auto local_n = item.get_local_id(2); // [0, 32)
+    auto local_m = item.get_local_id(1); // [0, 8)
+    for (auto tile_id = item.get_global_id(0);
+         tile_id < num_tile_n_ * num_tile_m_;
+         tile_id += item.get_group_range(0)) {
+      auto tile_id_n = tile_id % num_tile_n_;
+      auto tile_id_m = tile_id / num_tile_n_;
+      auto tile_actual_row_base = tile_id_m * tile_size_m_;
+      auto tile_actual_col_base = tile_id_n * tile_size_n_;
+      auto actual_column = tile_actual_col_base + local_n;
+      if (actual_column < N_) {
+        // slm_row 0, 8, 16, ..., 56
+        for (auto slm_row = 0; slm_row < tile_size_m_ / elements_per_thread_;
+             slm_row += num_subgroup_) {
+          accscalar_t sum_beta = accscalar_t(0);
+          accscalar_t sum_gamma = accscalar_t(0);
+          // row 0, 128, 256, ..., 896
+          auto row = tile_actual_row_base + slm_row * elements_per_thread_;
+          for (int i = 0; i < elements_per_thread_; i++) {
+            // row_local: row + 0, 8, 16, ..., 120
+            auto row_local = row + i * num_subgroup_;
+            auto actual_row = row_local + local_m;
+            // TODO: try tree reduction here if accuracy loss
+            if (actual_row < M_) {
+              if constexpr (have_beta) {
+                sum_beta += static_cast<accscalar_t>(
+                    dY_data_[actual_row * N_ + actual_column]);
+              }
+              if constexpr (have_gamma) {
+                sum_gamma += static_cast<accscalar_t>(
+                                 dY_data_[actual_row * N_ + actual_column]) *
+                    (static_cast<accscalar_t>(
+                         X_data_[actual_row * N_ + actual_column]) -
+                     static_cast<accscalar_t>(mean_data_[actual_row])) *
+                    static_cast<accscalar_t>(var_data_[actual_row]);
+              }
+            }
+          }
+
+          local_sum_beta_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_beta;
+          local_sum_gamma_[(slm_row + local_m) * tile_size_n_ + local_n] =
+              sum_gamma;
+        }
+
+        // No item.barrier(sycl_local_fence) needed here: each work-item
+        // reads back only the partial sums it wrote itself.
+        accscalar_t slm_sum_beta = accscalar_t(0);
+        accscalar_t slm_sum_gamma = accscalar_t(0);
+        // 64 slm rows, 8 subgroups: i = 0, 2, 4, 6
+        // 32 slm rows, 8 subgroups: i = 0, 2
+        // 16 slm rows, 8 subgroups: i = 0
+        for (int i = 0; i < tile_size_m_ / elements_per_thread_ / num_subgroup_;
+             i = i + 1) {
+          slm_sum_beta += local_sum_beta_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+          slm_sum_gamma += local_sum_gamma_
+              [(i * num_subgroup_ + local_m) * tile_size_n_ + local_n];
+        }
+        local_sum_beta_[local_m * tile_size_n_ + local_n] = slm_sum_beta;
+        local_sum_gamma_[local_m * tile_size_n_ + local_n] = slm_sum_gamma;
+      }
+      item.barrier(sycl_local_fence);
+      accscalar_t output_sum_beta = accscalar_t(0);
+      accscalar_t output_sum_gamma = accscalar_t(0);
+      if (local_m == 0 && actual_column < N_) {
+        for (int i = 0; i < num_subgroup_; i = i + 1) {
+          output_sum_beta += local_sum_beta_[i * tile_size_n_ + local_n];
+          output_sum_gamma += local_sum_gamma_[i * tile_size_n_ + local_n];
+        }
+        if constexpr (have_beta) {
+          db_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_beta);
+        }
+
+        if constexpr (have_gamma) {
+          dg_data_[tile_id_m * N_ + actual_column] =
+              static_cast<weight_t>(output_sum_gamma);
+        }
+      }
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    local_sum_beta_ = sycl_local_acc_t<accscalar_t>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+    local_sum_gamma_ = sycl_local_acc_t<accscalar_t>(
+        sycl::range<1>(tile_size_n_ * tile_size_m_ / elements_per_thread_),
+        cgh);
+  }
+
+  GammaBetaReduceFunctor(
+      const mean_t* mean_data,
+      const mean_t* var_data,
+      const scalar_t* dY_data,
+      const scalar_t* X_data,
+      weight_t* dg_block_data,
+      weight_t* db_block_data,
+      int64_t num_tile_m,
+      int64_t num_tile_n,
+      int64_t tile_size_m,
+      int64_t tile_size_n,
+      int64_t elements_per_thread,
+      int64_t num_subgroup,
+      int64_t M,
+      int64_t N)
+      : mean_data_(mean_data),
+        var_data_(var_data),
+        dY_data_(dY_data),
+        X_data_(X_data),
+        dg_data_(dg_block_data),
+        db_data_(db_block_data),
+        num_tile_m_(num_tile_m),
+        num_tile_n_(num_tile_n),
+        tile_size_m_(tile_size_m),
+        tile_size_n_(tile_size_n),
+        elements_per_thread_(elements_per_thread),
+        num_subgroup_(num_subgroup),
+        M_(M),
+        N_(N),
+        local_sum_beta_(),
+        local_sum_gamma_() {}
+
+ private:
+  const mean_t* mean_data_;
+  const mean_t* var_data_;
+  const scalar_t* dY_data_;
+  const scalar_t* X_data_;
+  weight_t* dg_data_;
+  weight_t* db_data_;
+  int64_t num_tile_m_;
+  int64_t num_tile_n_;
+  int64_t tile_size_m_;
+  int64_t tile_size_n_;
+  int64_t elements_per_thread_;
+  int64_t num_subgroup_;
+  int64_t M_;
+  int64_t N_;
+  sycl_local_acc_t<accscalar_t> local_sum_beta_;
+  sycl_local_acc_t<accscalar_t> local_sum_gamma_;
+};
+
 template <
     typename scalar_t,
     typename accscalar_t,
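Note on the hunk above: the functor is the first stage of a two-stage column reduction, producing one partial dgamma/dbeta row per M-tile. As a reference for the quantity each tile accumulates, here is a minimal scalar sketch (mine, not code from the patch; note that var_data_ is used multiplicatively in the functor, so it holds the reciprocal standard deviation, written as rstd below):

```cpp
#include <cstdint>

// Reference semantics only, a hypothetical helper. For every column j:
//   dbeta[j]  = sum_i dY[i][j]
//   dgamma[j] = sum_i dY[i][j] * (X[i][j] - mean[i]) * rstd[i]
// The kernel splits the sum over i into num_tile_m row tiles and writes one
// partial row per tile; a host-side sum over tiles finishes the reduction.
template <typename scalar_t, typename acc_t>
void gamma_beta_reference_scalar(
    const scalar_t* dY,
    const scalar_t* X,
    const acc_t* mean,
    const acc_t* rstd,
    acc_t* dgamma,
    acc_t* dbeta,
    int64_t M,
    int64_t N) {
  for (int64_t j = 0; j < N; ++j) {
    acc_t sum_gamma = 0, sum_beta = 0;
    for (int64_t i = 0; i < M; ++i) {
      acc_t dy = static_cast<acc_t>(dY[i * N + j]);
      sum_beta += dy;
      sum_gamma += dy * (static_cast<acc_t>(X[i * N + j]) - mean[i]) * rstd[i];
    }
    dgamma[j] = sum_gamma;
    dbeta[j] = sum_beta;
  }
}
```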
@@ -900,10 +1051,209 @@ void _layer_norm_backward_kernel(
           norm, config, can_use_32bit_index);
     }
   }
   auto config_w = NormConfig(M, N, 0, sizeof(scalar_t));
-  gamma_beta_bwd_simple_kernel(
-      dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  auto norm_config_global_size =
+      config_w.workgroup_num * config_w.block_row * config_w.workgroup_size;
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  // use the two-stage column reduction if the norm-config occupancy is < 50%
+  // TODO: we can relax this restriction in the future for better perf
+  bool use_two_stage_col_reduction =
+      (dY.dtype() == kFloat || dY.dtype() == kBFloat16 ||
+       dY.dtype() == kHalf) &&
+      norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
+  // CUDA uses the condition M > 64 * 1024 && N / 32 < sm_count / 2 to
+  // parallelize in the M dimension
+  if (use_two_stage_col_reduction && M > 64 * 1024 &&
+      N / 32 < syclGpuEuCount() / syclGpuEuCountPerSubslice() / 2) {
+    const size_t local_size_x = 8;
+    const size_t SIMD = 32;
+    // workgroup size is 256 work-items
+    // SLM usage is 16KB: 64 * 32 floats * 2 buffers
+    // elements_per_thread is at least 16
+    const int elements_per_thread = 16;
+    int tile_size_m = 1024;
+    int tile_size_n = N < 32 ? N : 32;
+    int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+    int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+    bool adjust_m = true;
+    // for M = 64 * 1024, N = 1, we choose tile size (256, 16) on PVC
+    // TODO: consider tuning the tile-size selection logic (tile_size_m,
+    // tile_size_n) and the occupancy calculation
+    for (auto i = 0; i < 3; i++) {
+      // occupancy <= 50%
+      if (num_tile_m * num_tile_n * local_size_x * SIMD /
+              syclMaxSubGroupSize() * 2 <=
+          thread_slots) {
+        if (adjust_m) {
+          tile_size_m /= 2;
+          num_tile_m = (M + tile_size_m - 1) / tile_size_m;
+          adjust_m = false;
+        } else {
+          tile_size_n /= 2;
+          num_tile_n = (N + tile_size_n - 1) / tile_size_n;
+          adjust_m = true;
+        }
+      } else {
+        break;
+      }
+    }
+    // the resulting tile size can be (1024, 32), (512, 32), (512, 16) or
+    // (256, 16). Modifying these parameters (num_subgroup, workgroup_size,
+    // tile_size, elements_per_thread) will alter the kernel configuration,
+    // potentially affecting performance and behavior.
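To make the tile-shrinking heuristic concrete, here is the same arithmetic as a standalone snippet with made-up device numbers (the real values come from syclGpuEuCount(), syclGpuHWThreadsPerEU() and syclMaxSubGroupSize(); 512 EUs and 8 HW threads per EU below are assumptions for illustration, not PVC specs):

```cpp
#include <cstdio>

int main() {
  // Assumed device numbers, for illustration only.
  const int eu_count = 512;        // stand-in for syclGpuEuCount()
  const int hw_threads_per_eu = 8; // stand-in for syclGpuHWThreadsPerEU()
  const int sub_group_size = 32;   // stand-in for syclMaxSubGroupSize()
  const int thread_slots = eu_count * hw_threads_per_eu; // 4096
  const int local_size_x = 8;
  const int SIMD = 32;
  const int M = 128 * 1024, N = 64;

  int tile_size_m = 1024;
  int tile_size_n = N < 32 ? N : 32;
  int num_tile_m = (M + tile_size_m - 1) / tile_size_m;
  int num_tile_n = (N + tile_size_n - 1) / tile_size_n;
  bool adjust_m = true;
  for (int i = 0; i < 3; i++) {
    // same <= 50% occupancy test as in the patch
    if (num_tile_m * num_tile_n * local_size_x * SIMD / sub_group_size * 2 <=
        thread_slots) {
      if (adjust_m) {
        tile_size_m /= 2;
        num_tile_m = (M + tile_size_m - 1) / tile_size_m;
        adjust_m = false;
      } else {
        tile_size_n /= 2;
        num_tile_n = (N + tile_size_n - 1) / tile_size_n;
        adjust_m = true;
      }
    } else {
      break;
    }
  }
  // prints: tile = (512, 32), grid = 256 x 2
  printf("tile = (%d, %d), grid = %d x %d\n",
         tile_size_m, tile_size_n, num_tile_m, num_tile_n);
  return 0;
}
```

With these assumed numbers, M = 128 * 1024 and N = 64 starts at a (1024, 32) tile, trips the occupancy test once, and settles on (512, 32) with a 256 x 2 tile grid.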
+    const scalar_t* dY_data = dY.const_data_ptr<scalar_t>();
+    const scalar_t* X_data = X.const_data_ptr<scalar_t>();
+    weight_t* dg_data =
+        dgamma.defined() ? dgamma.data_ptr<weight_t>() : nullptr;
+    weight_t* db_data = dbeta.defined() ? dbeta.data_ptr<weight_t>() : nullptr;
+    Tensor dgamma_blocks;
+    Tensor dbeta_blocks;
+    weight_t* dgamma_blocks_ptr = nullptr;
+    weight_t* dbeta_blocks_ptr = nullptr;
+    if (dgamma.defined()) {
+      auto options = dgamma.options();
+      // TODO: how to set dgamma_blocks dtype = float32?
+      dgamma_blocks = at::empty({num_tile_m, N}, options);
+      dgamma_blocks_ptr = dgamma_blocks.data_ptr<weight_t>();
+    }
+    if (dbeta.defined()) {
+      auto options = dbeta.options();
+      dbeta_blocks = at::empty({num_tile_m, N}, options);
+      dbeta_blocks_ptr = dbeta_blocks.data_ptr<weight_t>();
+    }
+
+    size_t num_workgroup = std::min(
+        num_tile_m * num_tile_n,
+        static_cast<int>(thread_slots / local_size_x));
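Before the three-way dispatch below, it may help to spell out the launch geometry. This is my reading of the brace-initialized arguments, assuming sycl_kernel_submit forwards them unchanged as the global and local ranges of a sycl::nd_range<3> (the helper name here is hypothetical, not part of the patch):

```cpp
#include <sycl/sycl.hpp>

// One work-group per tile slot; work-groups walk all num_tile_m * num_tile_n
// tiles grid-stride style via the tile_id loop in the functor.
sycl::nd_range<3> make_launch_range(
    size_t num_workgroup, size_t local_size_x, int tile_size_n, int SIMD) {
  size_t cols = static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD);
  // dim 0: tile slot, dim 1: sub-group row (local_m), dim 2: column (local_n)
  sycl::range<3> global{num_workgroup, local_size_x, cols};
  sycl::range<3> local{1, local_size_x, cols};
  return {global, local};
}
```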
+    if (dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+      dbeta = dbeta_blocks.sum(0);
+    } else if (dgamma.defined() && !dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          true,
+          false>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              true,
+              false>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dgamma = dgamma_blocks.sum(0);
+    } else if (!dgamma.defined() && dbeta.defined()) {
+      GammaBetaReduceFunctor<
+          scalar_t,
+          accscalar_t,
+          mean_t,
+          weight_t,
+          false,
+          true>
+          kfn(mean_data,
+              var_data,
+              dY_data,
+              X_data,
+              dgamma_blocks_ptr,
+              dbeta_blocks_ptr,
+              num_tile_m,
+              num_tile_n,
+              tile_size_m,
+              tile_size_n,
+              elements_per_thread,
+              local_size_x,
+              M,
+              N);
+
+      sycl_kernel_submit<
+          GammaBetaReduceFunctor<
+              scalar_t,
+              accscalar_t,
+              mean_t,
+              weight_t,
+              false,
+              true>,
+          3>(
+          {num_workgroup,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          {1,
+           local_size_x,
+           static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
+          getCurrentSYCLQueue(),
+          kfn);
+      dbeta = dbeta_blocks.sum(0);
+    } else {
+      return;
+    }
+
+  } else {
+    gamma_beta_bwd_simple_kernel(
+        dY, X, mean_data, var_data, dgamma, dbeta, config_w);
+  }
 }
 
 void layer_norm_kernel(
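Outside the patch itself, the two-stage path is straightforward to cross-check against plain tensor ops, since the per-tile partials plus the dgamma_blocks.sum(0) / dbeta_blocks.sum(0) epilogue must reproduce the full column reduction. A hypothetical ATen-level reference for such a check (names and shapes assumed: dY and X are [M, N]; mean and rstd are [M]):

```cpp
#include <ATen/ATen.h>
#include <utility>

// Hypothetical cross-check helper, not part of the patch.
// dgamma = sum over rows of dY * x_hat; dbeta = sum over rows of dY.
std::pair<at::Tensor, at::Tensor> gamma_beta_reference_aten(
    const at::Tensor& dY,
    const at::Tensor& X,
    const at::Tensor& mean,
    const at::Tensor& rstd) {
  // x_hat is the normalized input, broadcasting mean/rstd across columns
  auto x_hat = (X - mean.unsqueeze(1)) * rstd.unsqueeze(1);
  return {(dY * x_hat).sum(0), dY.sum(0)};
}
```

Comparing the kernel outputs against this reference with at::allclose (at a tolerance appropriate for the accumulation dtype) validates both the tiled first stage and the block-sum epilogue.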