@@ -5,11 +5,13 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
55
66 const int nrows = item_ct1.get_group_range (2 );
77 const int nchannels = item_ct1.get_group_range (1 );
8+ const int nthreads = item_ct1.get_local_range (2 );
89 const int sample = item_ct1.get_group (0 );
910 const int channel = item_ct1.get_group (1 );
1011 const int row = item_ct1.get_group (2 );
1112
1213 const int tid = item_ct1.get_local_id (2 );
14+ const int nwarps = nthreads / WARP_SIZE;
1315
1416 x += sample * stride_sample + channel * stride_channel + row * stride_row;
1517 dst += ((sample * nchannels + channel) * nrows + row) * ncols;
@@ -30,8 +32,12 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
3032 s_sum[warp_id] = mean_var;
3133 }
3234 item_ct1.barrier (sycl::access::fence_space::local_space);
33-
34- mean_var = s_sum[lane_id];
35+ mean_var = 0 .f ;
36+ size_t nreduce = nwarps / WARP_SIZE;
37+ for (size_t i = 0 ; i < nreduce; i += 1 )
38+ {
39+ mean_var += s_sum[lane_id + i * WARP_SIZE];
40+ }
3541 mean_var = warp_reduce_sum (mean_var, item_ct1);
3642 }
3743
@@ -139,8 +145,10 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
139145 const int sample = item_ct1.get_group (0 );
140146 const int channel = item_ct1.get_group (1 );
141147 const int row = item_ct1.get_group (2 );
148+ const int nthreads = item_ct1.get_local_range (2 );
142149
143150 const int tid = item_ct1.get_local_id (2 );
151+ const int nwarps = nthreads / WARP_SIZE;
144152
145153 x += sample*stride_sample + channel*stride_channel + row*stride_row;
146154 dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -164,7 +172,12 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
164172 }
165173
166174 item_ct1.barrier (sycl::access::fence_space::local_space);
167- tmp = s_sum[lane_id];
175+ size_t nreduce = nwarps / WARP_SIZE;
176+ tmp = 0 .f ;
177+ for (size_t i = 0 ; i < nreduce; i += 1 )
178+ {
179+ tmp += s_sum[lane_id + i * WARP_SIZE];
180+ }
168181 tmp = warp_reduce_sum (tmp, item_ct1);
169182 }
170183
0 commit comments