Skip to content

Commit 7c9721f

Browse files
committed
Address review comments: make the partial-sum reduction more SYCL-like by using sub-group queries
1 parent 545e47d commit 7c9721f

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

ggml/src/ggml-sycl/norm.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,17 +27,18 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
2727
// sum up partial sums
2828
mean_var = warp_reduce_sum(mean_var, item_ct1);
2929
if (block_size > WARP_SIZE) {
30-
int warp_id = tid / WARP_SIZE;
31-
int lane_id = tid % WARP_SIZE;
32-
if (lane_id == 0) {
33-
s_sum[warp_id] = mean_var;
30+
const auto sub_group = item_ct1.get_sub_group();
31+
const auto sg_id = sub_group.get_group_linear_id();
32+
const auto wi_in_sg = sub_group.get_local_linear_id();
33+
if (wi_in_sg == 0) {
34+
s_sum[sg_id] = mean_var;
3435
}
3536
item_ct1.barrier(sycl::access::fence_space::local_space);
3637
mean_var = 0.f;
37-
size_t nreduce = nwarps / WARP_SIZE;
38+
const size_t nreduce = (nwarps + WARP_SIZE - 1) / WARP_SIZE;
3839
for (size_t i = 0; i < nreduce; i += 1)
3940
{
40-
mean_var += s_sum[lane_id + i * WARP_SIZE];
41+
mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
4142
}
4243
mean_var = warp_reduce_sum(mean_var, item_ct1);
4344
}
@@ -165,19 +166,19 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
165166
// sum up partial sums
166167
tmp = warp_reduce_sum(tmp, item_ct1);
167168
if (block_size > WARP_SIZE) {
168-
169-
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
170-
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
171-
if (lane_id == 0) {
172-
s_sum[warp_id] = tmp;
169+
const auto sub_group = item_ct1.get_sub_group();
170+
const auto sg_id = sub_group.get_group_linear_id();
171+
const auto wi_in_sg = sub_group.get_local_linear_id();
172+
if (wi_in_sg == 0) {
173+
s_sum[sg_id] = tmp;
173174
}
174175

175176
item_ct1.barrier(sycl::access::fence_space::local_space);
176-
size_t nreduce = nwarps / WARP_SIZE;
177+
const size_t nreduce = (nwarps + WARP_SIZE - 1) / WARP_SIZE;
177178
tmp = 0.f;
178179
for (size_t i = 0; i < nreduce; i += 1)
179180
{
180-
tmp += s_sum[lane_id + i * WARP_SIZE];
181+
tmp += s_sum[wi_in_sg + i * WARP_SIZE];
181182
}
182183
tmp = warp_reduce_sum(tmp, item_ct1);
183184
}

0 commit comments

Comments
 (0)