diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 727163b7fdf95..02c69fad5a31f 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3780,15 +3780,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const int ne10 = src1 ? src1->ne[0] : 0; - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb1 = dst ? dst->nb[1] : 0; - const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const int ne00 = src0->ne[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + const int ne10 = src1->ne[0]; + const cl_ulong nb10 = src1->nb[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -3825,14 +3829,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); - - size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; - size_t local_work_size[] = {1, 1, 1}; + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index b3fea2923df8f..c2962edc98372 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { float16 temp; + if (ind >= ne00) { + return; + } dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp; } }