Skip to content

Commit 52e5d42

Browse files
authored
opencl: fix rms_norm_mul (#17250)
* opencl: use subgroup reduce for reduction in rms_norm_mul
* opencl: add comment about workgroup size
1 parent 4db5641 commit 52e5d42

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5705,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
57055705
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
57065706
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
57075707
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
5708-
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
5708+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
57095709

57105710
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
57115711
}

ggml/src/ggml-opencl/kernels/rms_norm.cl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
134134
src1 = src1 + offset1;
135135
dst = dst + offsetd;
136136

137+
// The size of sum is sizeof(float)*subgroup_size.
138+
// Each subgroup writes its partial sum to this array.
139+
// So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
140+
// This is generally true -
141+
// for subgroup size 64, the workgroup size may be at most 64*64 = 4096 (the max is usually 1024).
142+
if (get_sub_group_id() == 0) {
143+
sum[get_sub_group_local_id()] = 0.0f;
144+
}
145+
137146
int i03 = get_group_id(2);
138147
int i02 = get_group_id(1);
139148
int i01 = get_group_id(0);
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
148157
sumf += dot(x[i00], x[i00]);
149158
}
150159
sumf = sub_group_reduce_add(sumf);
160+
161+
barrier(CLK_LOCAL_MEM_FENCE);
162+
151163
if (get_sub_group_local_id() == 0) {
152164
sum[get_sub_group_id()] = sumf;
153165
}
154166

155167
barrier(CLK_LOCAL_MEM_FENCE);
156168

157-
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
158-
if (get_local_id(0) < i) {
159-
sum[get_local_id(0)] += sum[get_local_id(0) + i];
160-
}
161-
}
162-
if (get_local_id(0) == 0) {
163-
sum[0] /= ne00;
164-
}
169+
//for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
170+
// if (get_local_id(0) < i) {
171+
// sum[get_local_id(0)] += sum[get_local_id(0) + i];
172+
// }
173+
//}
174+
//if (get_local_id(0) == 0) {
175+
// sum[0] /= ne00;
176+
//}
165177

166-
barrier(CLK_LOCAL_MEM_FENCE);
178+
//barrier(CLK_LOCAL_MEM_FENCE);
179+
180+
sumf = sum[get_sub_group_local_id()];
181+
sumf = sub_group_reduce_add(sumf);
167182

168-
float mean = sum[0];
183+
float mean = sumf / ne00;
169184
float scale = 1.0f/sqrt(mean + eps);
170185

171186
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);

0 commit comments

Comments
 (0)