Skip to content

Commit 9129362

Browse files
Commit message: SYCL softmax: Initialize nreduce as size_t
1 parent: fe5afd4 · commit: 9129362

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
     const int nthreads = block_size;
     const int nwarps = nthreads / WARP_SIZE;
-    int nreduce = nwarps / WARP_SIZE;
+    size_t nreduce = nwarps / WARP_SIZE;
     float slope = 1.0f;
 
     // ALiBi
@@ -53,7 +53,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
             buf[lane_id] = -INFINITY;
-            for (size_t i = 1; i < (size_t) nreduce; i += 1) {
+            for (size_t i = 1; i < nreduce; i += 1) {
                 buf[lane_id + i * WARP_SIZE] = -INFINITY;
             }
         }
@@ -64,7 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         }
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
-        for (size_t i = 1; i < (size_t) nreduce; i += 1) {
+        for (size_t i = 1; i < nreduce; i += 1) {
             max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
@@ -89,7 +89,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         if (warp_id == 0) {
             buf[lane_id] = 0.f;
-            for (size_t i = 1; i < (size_t) nreduce; i += 1) {
+            for (size_t i = 1; i < nreduce; i += 1) {
                 buf[lane_id + i * WARP_SIZE] = 0.f;
             }
         }
@@ -101,7 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
 
         tmp = buf[lane_id];
-        for (size_t i = 1; i < (size_t) nreduce; i += 1) {
+        for (size_t i = 1; i < nreduce; i += 1) {
             tmp += buf[lane_id + i * WARP_SIZE];
         }
         tmp = warp_reduce_sum(tmp, item_ct1);

0 commit comments

Comments
 (0)