Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3780,15 +3780,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);

const int ne00 = src0 ? src0->ne[0] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const int ne10 = src1 ? src1->ne[0] : 0;
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
const int ne11 = src1 ? src1->ne[1] : 0;
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
const int ne00 = src0->ne[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const cl_ulong nb10 = src1->nb[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

Expand Down Expand Up @@ -3825,14 +3829,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));

size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
size_t local_work_size[] = {1, 1, 1};
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));

size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
size_t local_work_size[] = {64, 1, 1};

backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
Expand Down
48 changes: 36 additions & 12 deletions ggml/src/ggml-opencl/kernels/get_rows.cl
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,34 @@ kernel void kernel_get_rows_f32(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);

int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);

int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

int i02 = i11;
int i03 = i12;

for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
if (ind >= ne00) {
return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
}
}

Expand All @@ -102,26 +110,34 @@ kernel void kernel_get_rows_f16(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);

int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);

int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

int i02 = i11;
int i03 = i12;

for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
if (ind >= ne00) {
return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
}
}

Expand All @@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
Expand All @@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(

int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);

int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

int i02 = i11;
int i03 = i12;

for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
float16 temp;
if (ind >= ne00) {
return;
}
dequantize_q4_0_f32(
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
}
}