Skip to content

Commit e6e3918

Browse files
committed
opencl: add add_rows
1 parent c785441 commit e6e3918

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ set(GGML_OPENCL_KERNELS
9090
softmax_4_f16
9191
softmax_f32
9292
softmax_f16
93+
sum_rows
9394
transpose
9495
)
9596

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ struct ggml_backend_opencl_context {
310310
cl_program program_softmax_4_f32;
311311
cl_program program_softmax_4_f16;
312312
cl_program program_argsort_f32_i32;
313+
cl_program program_sum_rows_f32;
313314

314315
cl_kernel kernel_add, kernel_add_row;
315316
cl_kernel kernel_mul, kernel_mul_row;
@@ -342,6 +343,7 @@ struct ggml_backend_opencl_context {
342343
cl_kernel kernel_mul_mv_q6_K_f32;
343344
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
344345
cl_kernel kernel_argsort_f32_i32;
346+
cl_kernel kernel_sum_rows_f32;
345347

346348
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
347349
// Transpose kernels
@@ -1022,6 +1024,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10221024
GGML_LOG_CONT(".");
10231025
}
10241026

1027+
// sum_rows
1028+
{
1029+
#ifdef GGML_OPENCL_EMBED_KERNELS
1030+
const std::string kernel_src {
1031+
#include "sum_rows.cl.h"
1032+
};
1033+
#else
1034+
const std::string kernel_src = read_file("sum_rows.cl");
1035+
#endif
1036+
backend_ctx->program_sum_rows_f32 =
1037+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1038+
1039+
CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
1040+
GGML_LOG_CONT(".");
1041+
}
1042+
10251043
// Adreno kernels
10261044
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
10271045
// transpose
@@ -1951,6 +1969,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
19511969
return true;
19521970
case GGML_OP_ARGSORT:
19531971
return op->src[0]->type == GGML_TYPE_F32;
1972+
case GGML_OP_SUM_ROWS:
1973+
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
19541974
default:
19551975
return false;
19561976
}
@@ -5194,6 +5214,69 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co
51945214
#endif
51955215
}
51965216

5217+
static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5218+
GGML_ASSERT(src0);
5219+
GGML_ASSERT(src0->extra);
5220+
GGML_ASSERT(dst);
5221+
GGML_ASSERT(dst->extra);
5222+
GGML_UNUSED(src1);
5223+
5224+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
5225+
GGML_ASSERT(ggml_is_contiguous(src0));
5226+
5227+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5228+
cl_command_queue queue = backend_ctx->queue;
5229+
5230+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5231+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5232+
5233+
cl_ulong offset0 = extra0->offset + src0->view_offs;
5234+
cl_ulong offsetd = extrad->offset + dst->view_offs;
5235+
5236+
const int ne00 = src0->ne[0];
5237+
const int ne01 = src0->ne[1];
5238+
const int ne02 = src0->ne[2];
5239+
const int ne03 = src0->ne[3];
5240+
5241+
const cl_ulong nb01 = src0->nb[1];
5242+
const cl_ulong nb02 = src0->nb[2];
5243+
const cl_ulong nb03 = src0->nb[3];
5244+
5245+
const cl_ulong nb1 = dst->nb[1];
5246+
const cl_ulong nb2 = dst->nb[2];
5247+
const cl_ulong nb3 = dst->nb[3];
5248+
5249+
cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
5250+
5251+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5252+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5253+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
5254+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
5255+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
5256+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
5257+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
5258+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
5259+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
5260+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
5261+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
5262+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
5263+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
5264+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
5265+
5266+
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
5267+
size_t local_work_size[] = {(size_t)64, 1, 1};
5268+
5269+
#ifdef GGML_OPENCL_PROFILING
5270+
cl_event evt;
5271+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
5272+
5273+
g_profiling_info.emplace_back();
5274+
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
5275+
#else
5276+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
5277+
#endif
5278+
}
5279+
51975280
//------------------------------------------------------------------------------
51985281
// Op offloading
51995282
//------------------------------------------------------------------------------
@@ -5346,6 +5429,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
53465429
}
53475430
func = ggml_cl_argsort;
53485431
break;
5432+
case GGML_OP_SUM_ROWS:
5433+
if (!any_on_device) {
5434+
return false;
5435+
}
5436+
func = ggml_cl_sum_rows;
5437+
break;
53495438
default:
53505439
return false;
53515440
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
kernel void kernel_sum_rows_f32(
3+
global float * src0,
4+
ulong offset0,
5+
global float * dst,
6+
ulong offsetd,
7+
int ne00,
8+
int ne01,
9+
int ne02,
10+
int ne03,
11+
ulong nb01,
12+
ulong nb02,
13+
ulong nb03,
14+
ulong nb1,
15+
ulong nb2,
16+
ulong nb3
17+
) {
18+
src0 = (global float *)((global char *)src0 + offset0);
19+
dst = (global float *)((global char *)dst + offsetd);
20+
21+
int i3 = get_global_id(2);
22+
int i2 = get_global_id(1);
23+
int i1 = get_global_id(0);
24+
25+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
26+
return;
27+
}
28+
29+
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
30+
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
31+
32+
float row_sum = 0;
33+
34+
for (int i0 = 0; i0 < ne00; i0++) {
35+
row_sum += src_row[i0];
36+
}
37+
38+
dst_row[0] = row_sum;
39+
}

0 commit comments

Comments
 (0)