@@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
1616 const int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
1717 const int nthreads = block_size;
1818 const int nwarps = nthreads / WARP_SIZE;
19- int nreduce = nwarps / WARP_SIZE;
19+ size_t nreduce = nwarps / WARP_SIZE;
2020 float slope = 1 .0f ;
2121
2222 // ALiBi
@@ -53,7 +53,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
5353 if (block_size > WARP_SIZE) {
5454 if (warp_id == 0 ) {
5555 buf[lane_id] = -INFINITY;
56- for (size_t i = 1 ; i < ( size_t ) nreduce; i += 1 ) {
56+ for (size_t i = 1 ; i < nreduce; i += 1 ) {
5757 buf[lane_id + i * WARP_SIZE] = -INFINITY;
5858 }
5959 }
@@ -64,7 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
6464 }
6565 item_ct1.barrier (sycl::access::fence_space::local_space);
6666 max_val = buf[lane_id];
67- for (size_t i = 1 ; i < ( size_t ) nreduce; i += 1 ) {
67+ for (size_t i = 1 ; i < nreduce; i += 1 ) {
6868 max_val = std::max (max_val, buf[lane_id + i * WARP_SIZE]);
6969 }
7070 max_val = warp_reduce_max (max_val, item_ct1);
@@ -89,7 +89,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
8989 item_ct1.barrier (sycl::access::fence_space::local_space);
9090 if (warp_id == 0 ) {
9191 buf[lane_id] = 0 .f ;
92- for (size_t i = 1 ; i < ( size_t ) nreduce; i += 1 ) {
92+ for (size_t i = 1 ; i < nreduce; i += 1 ) {
9393 buf[lane_id + i * WARP_SIZE] = 0 .f ;
9494 }
9595 }
@@ -101,7 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
101101 item_ct1.barrier (sycl::access::fence_space::local_space);
102102
103103 tmp = buf[lane_id];
104- for (size_t i = 1 ; i < ( size_t ) nreduce; i += 1 ) {
104+ for (size_t i = 1 ; i < nreduce; i += 1 ) {
105105 tmp += buf[lane_id + i * WARP_SIZE];
106106 }
107107 tmp = warp_reduce_sum (tmp, item_ct1);
0 commit comments