@@ -709,13 +709,13 @@ struct GammaBetaReduceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
709
709
local_sum_gamma_[(i + 1 ) * tile_size_n_ + local_n];
710
710
}
711
711
if (db_data_ != nullptr )
712
- db_data_[tile_id_m * tile_size_n_ + actual_column] =
712
+ db_data_[tile_id_m * N_ + actual_column] =
713
713
(static_cast <weight_t >(output_sum_beta[0 ]) +
714
714
static_cast <weight_t >(output_sum_beta[1 ])) +
715
715
(static_cast <weight_t >(output_sum_beta[2 ]) +
716
716
static_cast <weight_t >(output_sum_beta[3 ]));
717
717
if (dg_data_ != nullptr )
718
- dg_data_[tile_id_m * tile_size_n_ + actual_column] =
718
+ dg_data_[tile_id_m * N_ + actual_column] =
719
719
(static_cast <weight_t >(output_sum_gamma[0 ]) +
720
720
static_cast <weight_t >(output_sum_gamma[1 ])) +
721
721
(static_cast <weight_t >(output_sum_gamma[2 ]) +
@@ -1151,8 +1151,8 @@ void _layer_norm_backward_kernel(
1151
1151
sycl_kernel_submit<
1152
1152
GammaBetaReduceFunctor<scalar_t , accscalar_t , mean_t , weight_t >,
1153
1153
3 >(
1154
- {num_workgroup, local_size_x, static_cast <size_t >(N < SIMD ? N : SIMD)},
1155
- {1 , local_size_x, static_cast <size_t >(N < SIMD ? N : SIMD)},
1154
+ {num_workgroup, local_size_x, static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1155
+ {1 , local_size_x, static_cast <size_t >(tile_size_n < SIMD ? tile_size_n : SIMD)},
1156
1156
getCurrentSYCLQueue (),
1157
1157
kfn);
1158
1158
dgamma = dgamma_blocks.sum (0 );
0 commit comments