@@ -27,17 +27,18 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
2727 // sum up partial sums
2828 mean_var = warp_reduce_sum (mean_var, item_ct1);
2929 if (block_size > WARP_SIZE) {
30- int warp_id = tid / WARP_SIZE;
31- int lane_id = tid % WARP_SIZE;
32- if (lane_id == 0 ) {
33- s_sum[warp_id] = mean_var;
30+ const auto sub_group = item_ct1.get_sub_group ();
31+ const auto sg_id = sub_group.get_group_linear_id ();
32+ const auto wi_in_sg = sub_group.get_local_linear_id ();
33+ if (wi_in_sg == 0 ) {
34+ s_sum[sg_id] = mean_var;
3435 }
3536 item_ct1.barrier (sycl::access::fence_space::local_space);
3637 mean_var = 0 .f ;
37- size_t nreduce = nwarps / WARP_SIZE;
38+ const size_t nreduce = ( nwarps + WARP_SIZE - 1 ) / WARP_SIZE;
3839 for (size_t i = 0 ; i < nreduce; i += 1 )
3940 {
40- mean_var += s_sum[lane_id + i * WARP_SIZE];
41+ mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
4142 }
4243 mean_var = warp_reduce_sum (mean_var, item_ct1);
4344 }
@@ -165,19 +166,19 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
165166 // sum up partial sums
166167 tmp = warp_reduce_sum (tmp, item_ct1);
167168 if (block_size > WARP_SIZE) {
168-
169- int warp_id = item_ct1. get_local_id ( 2 ) / WARP_SIZE ;
170- int lane_id = item_ct1. get_local_id ( 2 ) % WARP_SIZE ;
171- if (lane_id == 0 ) {
172- s_sum[warp_id ] = tmp;
169+ const auto sub_group = item_ct1. get_sub_group ();
170+ const auto sg_id = sub_group. get_group_linear_id () ;
171+ const auto wi_in_sg = sub_group. get_local_linear_id () ;
172+ if (wi_in_sg == 0 ) {
173+ s_sum[sg_id ] = tmp;
173174 }
174175
175176 item_ct1.barrier (sycl::access::fence_space::local_space);
176- size_t nreduce = nwarps / WARP_SIZE;
177+ const size_t nreduce = ( nwarps + WARP_SIZE - 1 ) / WARP_SIZE;
177178 tmp = 0 .f ;
178179 for (size_t i = 0 ; i < nreduce; i += 1 )
179180 {
180- tmp += s_sum[lane_id + i * WARP_SIZE];
181+ tmp += s_sum[wi_in_sg + i * WARP_SIZE];
181182 }
182183 tmp = warp_reduce_sum (tmp, item_ct1);
183184 }
0 commit comments