Skip to content

Commit d011e27

Browse files
committed
opencl: support noncontiguous norm
1 parent 94bb63e commit d011e27

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,26 +2573,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
25732573
memcpy(&eps, dst->op_params, sizeof(float));
25742574

25752575
const int ne00 = src0 ? src0->ne[0] : 0;
2576-
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2576+
const int ne01 = src0 ? src0->ne[1] : 0;
2577+
const int ne02 = src0 ? src0->ne[2] : 0;
2578+
const int ne03 = src0 ? src0->ne[3] : 0;
25772579

2578-
GGML_ASSERT(ggml_is_contiguous_1(src0));
2580+
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2581+
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
2582+
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
25792583

25802584
const int nth = MIN(64, ne00);
25812585

25822586
cl_kernel kernel = backend_ctx->kernel_norm;
25832587

2584-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2585-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2586-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2587-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2588-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2589-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
2590-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
2591-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL));
2588+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2589+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2590+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2591+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2592+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2593+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
2594+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
2595+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
2596+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
2597+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
2598+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
2599+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
2600+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
25922601

2593-
const int64_t nrows = ggml_nrows(src0);
2594-
2595-
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
2602+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
25962603
size_t local_work_size[] = {(size_t)nth, 1, 1};
25972604

25982605
#ifdef GGML_OPENCL_PROFILING

ggml/src/ggml-opencl/kernels/ggml-opencl.cl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,14 +506,23 @@ kernel void kernel_norm(
506506
global float * dst,
507507
ulong offsetd,
508508
int ne00,
509+
int ne01,
510+
int ne02,
511+
int ne03,
509512
ulong nb01,
513+
ulong nb02,
514+
ulong nb03,
510515
float eps,
511516
local float * sum
512517
) {
513518
src0 = (global void*)((global char*)src0 + offset0);
514519
dst = (global void*)((global char*)dst + offsetd);
515520

516-
global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
521+
int i03 = get_group_id(2);
522+
int i02 = get_group_id(1);
523+
int i01 = get_group_id(0);
524+
525+
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
517526

518527
// MEAN
519528
// parallel sum
@@ -533,7 +542,7 @@ kernel void kernel_norm(
533542

534543
// recenter and VARIANCE
535544
barrier(CLK_LOCAL_MEM_FENCE);
536-
global float * y = dst + get_group_id(0)*ne00;
545+
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
537546
sum[get_local_id(0)] = 0.0f;
538547
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
539548
y[i00] = x[i00] - mean;

0 commit comments

Comments
 (0)