@@ -31,8 +31,8 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
3131 */
3232 item_ct1.barrier (sycl::access::fence_space::local_space);
3333 mean_var = 0 .f ;
34- int nreduce = nwarps / WARP_SIZE;
35- for (size_t i = 0 ; i < ( size_t ) nreduce; i += 1 )
34+ size_t nreduce = nwarps / WARP_SIZE;
35+ for (size_t i = 0 ; i < nreduce; i += 1 )
3636 {
3737 mean_var += s_sum[lane_id + i * WARP_SIZE];
3838 }
@@ -55,7 +55,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
5555 const int nthreads = item_ct1.get_local_range (2 );
5656 const int nwarps = nthreads / WARP_SIZE;
5757 start += item_ct1.get_local_id (2 );
58- int nreduce = nwarps / WARP_SIZE;
58+ size_t nreduce = nwarps / WARP_SIZE;
5959
6060 if (end >= ne_elements) {
6161 end = ne_elements;
@@ -86,7 +86,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
8686 */
8787 item_ct1.barrier ();
8888 tmp = 0 .f ;
89- for (size_t i = 0 ; i < ( size_t ) nreduce; i += 1 )
89+ for (size_t i = 0 ; i < nreduce; i += 1 )
9090 {
9191 tmp += s_sum[lane_id + i * WARP_SIZE];
9292 }
@@ -121,7 +121,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
121121 */
122122 item_ct1.barrier ();
123123 tmp = 0 .f ;
124- for (size_t i = 0 ; i < ( size_t ) nreduce; i += 1 )
124+ for (size_t i = 0 ; i < nreduce; i += 1 )
125125 {
126126 tmp += s_sum[lane_id + i * WARP_SIZE];
127127 }
@@ -163,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
163163 converged control flow. You may need to adjust the code.
164164 */
165165 item_ct1.barrier (sycl::access::fence_space::local_space);
166- int nreduce = nwarps / WARP_SIZE;
166+ size_t nreduce = nwarps / WARP_SIZE;
167167 tmp = 0 .f ;
168168 for (size_t i = 0 ; i < nreduce; i += 1 )
169169 {
0 commit comments