Skip to content

Commit 7b8568f

Browse files
committed
opencl: improve workgroup size for rms_norm_mul
1 parent 4e524fd commit 7b8568f

File tree

1 file changed: 13 additions (+13) and 7 deletions (-7)

1 file changed

+13
-7
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4563,13 +4563,6 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
45634563

45644564
GGML_ASSERT(ne00 % 4 == 0);
45654565

4566-
const int nth = MIN(64, ne00);
4567-
4568-
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
4569-
size_t local_work_size[] = {(size_t)nth, 1, 1};
4570-
4571-
cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
4572-
45734566
size_t sgs;
45744567
if (backend_ctx->gpu_family == ADRENO) {
45754568
sgs = 64;
@@ -4579,6 +4572,19 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
45794572
GGML_ASSERT(false && "Unsupported GPU");
45804573
}
45814574

4575+
cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
4576+
4577+
int nth = sgs;
4578+
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
4579+
while (nth < ne00 && nth < max_workgroup_size) {
4580+
nth *= 2;
4581+
}
4582+
nth = MIN(nth, max_workgroup_size);
4583+
nth = MIN(nth, ne00);
4584+
4585+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
4586+
size_t local_work_size[] = {(size_t)nth, 1, 1};
4587+
45824588
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
45834589
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
45844590
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));

0 commit comments

Comments (0)