Skip to content

Commit c436b11

Browse files
committed
save
1 parent 484ffa6 commit c436b11

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/ATen/native/xpu/sycl/LayerNormKernels.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -709,13 +709,13 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
709709
local_sum_gamma_[(i + 1) * tile_size_n_ + local_n];
710710
}
711711
if (db_data_ != nullptr)
712-
db_data_[tile_id_m * tile_size_n_ + actual_column] =
712+
db_data_[tile_id_m * N_ + actual_column] =
713713
(static_cast<weight_t>(output_sum_beta[0]) +
714714
static_cast<weight_t>(output_sum_beta[1])) +
715715
(static_cast<weight_t>(output_sum_beta[2]) +
716716
static_cast<weight_t>(output_sum_beta[3]));
717717
if (dg_data_ != nullptr)
718-
dg_data_[tile_id_m * tile_size_n_ + actual_column] =
718+
dg_data_[tile_id_m * N_ + actual_column] =
719719
(static_cast<weight_t>(output_sum_gamma[0]) +
720720
static_cast<weight_t>(output_sum_gamma[1])) +
721721
(static_cast<weight_t>(output_sum_gamma[2]) +
@@ -1151,8 +1151,8 @@ void _layer_norm_backward_kernel(
11511151
sycl_kernel_submit<
11521152
GammaBetaReduceFunctor<scalar_t, accscalar_t, mean_t, weight_t>,
11531153
3>(
1154-
{num_workgroup, local_size_x, static_cast<size_t>(N < SIMD ? N : SIMD)},
1155-
{1, local_size_x, static_cast<size_t>(N < SIMD ? N : SIMD)},
1154+
{num_workgroup, local_size_x, static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
1155+
{1, local_size_x, static_cast<size_t>(tile_size_n < SIMD ? tile_size_n : SIMD)},
11561156
getCurrentSYCLQueue(),
11571157
kfn);
11581158
dgamma = dgamma_blocks.sum(0);

0 commit comments

Comments
 (0)