Skip to content

Commit 3eca4f6

Browse files
committed
opencl: support noncontiguous rms_norm
1 parent d011e27 commit 3eca4f6

File tree

2 files changed: 31 additions & 14 deletions

2 files changed

+31
-14
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2637,16 +2637,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26372637
memcpy(&eps, dst->op_params, sizeof(float));
26382638

26392639
const int ne00 = src0 ? src0->ne[0] : 0;
2640+
const int ne01 = src0 ? src0->ne[1] : 0;
2641+
const int ne02 = src0 ? src0->ne[2] : 0;
2642+
const int ne03 = src0 ? src0->ne[3] : 0;
2643+
26402644
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2645+
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
2646+
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
26412647

26422648
GGML_ASSERT(ne00 % 4 == 0);
2643-
GGML_ASSERT(ggml_is_contiguous_1(src0));
26442649

26452650
const int nth = MIN(64, ne00);
26462651

2647-
const int64_t nrows = ggml_nrows(src0);
2648-
2649-
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
2652+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
26502653
size_t local_work_size[] = {(size_t)nth, 1, 1};
26512654

26522655
cl_kernel kernel = backend_ctx->kernel_rms_norm;
@@ -2661,15 +2664,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26612664
sizeof(local_work_size), local_work_size,
26622665
sizeof(size_t), &sgs, NULL));
26632666

2664-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2665-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2666-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2667-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2668-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2669-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
2670-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
2667+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2668+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2669+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2670+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2671+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2672+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
2673+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
2674+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
2675+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
2676+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
2677+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
2678+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
26712679
// This is local memory - the size depends on subgroup size.
2672-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL));
2680+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL));
26732681

26742682
#ifdef GGML_OPENCL_PROFILING
26752683
cl_event evt;

ggml/src/ggml-opencl/kernels/ggml-opencl.cl

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -575,14 +575,23 @@ kernel void kernel_rms_norm(
575575
global float * dst,
576576
ulong offsetd,
577577
int ne00,
578+
int ne01,
579+
int ne02,
580+
int ne03,
578581
ulong nb01,
582+
ulong nb02,
583+
ulong nb03,
579584
float eps,
580585
local float * sum // Note, the size depends on number of subgroups
581586
) {
582587
src0 = (global void*)((global char*)src0 + offset0);
583588
dst = (global float*)((global char*)dst + offsetd);
584589

585-
global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
590+
int i03 = get_group_id(2);
591+
int i02 = get_group_id(1);
592+
int i01 = get_group_id(0);
593+
594+
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
586595
global float * x_scalar = (global float *) x;
587596
float4 sumf = 0;
588597
float all_sum = 0;
@@ -616,7 +625,7 @@ kernel void kernel_rms_norm(
616625
const float mean = sum[0];
617626
const float scale = 1.0f/sqrt(mean + eps);
618627

619-
global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
628+
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
620629
global float * y_scalar = (global float *) y;
621630
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
622631
y[i00] = x[i00] * scale;

0 commit comments

Comments
 (0)