Skip to content

Commit 32164aa

Browse files
committed
Initialize nreduce as size_t
1 parent fb2e66e commit 32164aa

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml/src/ggml-sycl/norm.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
3131
*/
3232
item_ct1.barrier(sycl::access::fence_space::local_space);
3333
mean_var = 0.f;
34-
int nreduce = nwarps / WARP_SIZE;
35-
for (size_t i = 0; i < (size_t) nreduce; i += 1)
34+
size_t nreduce = nwarps / WARP_SIZE;
35+
for (size_t i = 0; i < nreduce; i += 1)
3636
{
3737
mean_var += s_sum[lane_id + i * WARP_SIZE];
3838
}
@@ -55,7 +55,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
5555
const int nthreads = item_ct1.get_local_range(2);
5656
const int nwarps = nthreads / WARP_SIZE;
5757
start += item_ct1.get_local_id(2);
58-
int nreduce = nwarps / WARP_SIZE;
58+
size_t nreduce = nwarps / WARP_SIZE;
5959

6060
if (end >= ne_elements) {
6161
end = ne_elements;
@@ -86,7 +86,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
8686
*/
8787
item_ct1.barrier();
8888
tmp = 0.f;
89-
for (size_t i = 0; i < (size_t) nreduce; i += 1)
89+
for (size_t i = 0; i < nreduce; i += 1)
9090
{
9191
tmp += s_sum[lane_id + i * WARP_SIZE];
9292
}
@@ -121,7 +121,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
121121
*/
122122
item_ct1.barrier();
123123
tmp = 0.f;
124-
for (size_t i = 0; i < (size_t) nreduce; i += 1)
124+
for (size_t i = 0; i < nreduce; i += 1)
125125
{
126126
tmp += s_sum[lane_id + i * WARP_SIZE];
127127
}
@@ -163,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
163163
converged control flow. You may need to adjust the code.
164164
*/
165165
item_ct1.barrier(sycl::access::fence_space::local_space);
166-
int nreduce = nwarps / WARP_SIZE;
166+
size_t nreduce = nwarps / WARP_SIZE;
167167
tmp = 0.f;
168168
for (size_t i = 0; i < nreduce; i += 1)
169169
{

0 commit comments

Comments
 (0)