Skip to content

Commit 9d7137f

Browse files
shaofeiqi authored and ggerganov committed
opencl: add SOFTPLUS op support (llama/18726)
1 parent d8faba2 commit 9d7137f

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ set(GGML_OPENCL_KERNELS
122122
upscale
123123
tanh
124124
expm1
125+
softplus
125126
pad
126127
repeat
127128
mul_mat_f16_f32

src/ggml-opencl/ggml-opencl.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,8 @@ struct ggml_backend_opencl_context {
540540
cl_kernel kernel_tanh_f16_nd;
541541
cl_kernel kernel_expm1_f32_nd;
542542
cl_kernel kernel_expm1_f16_nd;
543+
cl_kernel kernel_softplus_f32_nd;
544+
cl_kernel kernel_softplus_f16_nd;
543545
cl_kernel kernel_upscale;
544546
cl_kernel kernel_upscale_bilinear;
545547
cl_kernel kernel_concat_f32_contiguous;
@@ -1826,6 +1828,31 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
18261828
CL_CHECK(clReleaseProgram(prog));
18271829
}
18281830

1831+
// softplus
1832+
{
1833+
#ifdef GGML_OPENCL_EMBED_KERNELS
1834+
const std::string kernel_src {
1835+
#include "softplus.cl.h"
1836+
};
1837+
#else
1838+
const std::string kernel_src = read_file("softplus.cl");
1839+
#endif
1840+
cl_program prog;
1841+
if (!kernel_src.empty()) {
1842+
prog =
1843+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1844+
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
1845+
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
1846+
GGML_LOG_CONT(".");
1847+
} else {
1848+
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
1849+
prog = nullptr;
1850+
backend_ctx->kernel_softplus_f32_nd = nullptr;
1851+
backend_ctx->kernel_softplus_f16_nd = nullptr;
1852+
}
1853+
CL_CHECK(clReleaseProgram(prog));
1854+
}
1855+
18291856
// upscale
18301857
{
18311858
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3138,6 +3165,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
31383165
case GGML_UNARY_OP_EXPM1:
31393166
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
31403167
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
3168+
case GGML_UNARY_OP_SOFTPLUS:
3169+
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
3170+
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
31413171
default:
31423172
return false;
31433173
}
@@ -6596,6 +6626,108 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
65966626
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
65976627
}
65986628

6629+
// Enqueues the element-wise softplus kernel that writes softplus(src0) into dst.
//
//   backend - backend whose context holds the softplus kernel handles
//   src0    - input tensor; must carry a device-side extra
//   src1    - unused; kept so the function has the same signature as the
//             other op functions it is dispatched alongside
//   dst     - output tensor; must carry a device-side extra. Its type
//             (F32 or F16) selects the kernel variant; any other type asserts.
//
// The launch grid covers the first three dst dimensions; the fourth dimension
// is passed to the kernel as an argument. Nothing is enqueued when dst has a
// zero-sized dimension.
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    UNUSED(src1);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    // Start from null so the assert below also covers any path that fails to
    // pick a kernel, instead of reading an indeterminate handle.
    cl_kernel kernel = nullptr;
    if (dst->type == GGML_TYPE_F32) {
        kernel = backend_ctx->kernel_softplus_f32_nd;
    } else if (dst->type == GGML_TYPE_F16) {
        kernel = backend_ctx->kernel_softplus_f16_nd;
    } else {
        GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
    }
    GGML_ASSERT(kernel != nullptr);

    // The kernel takes extents as int and strides as ulong.
    const int src_ne[4] = { (int)src0->ne[0], (int)src0->ne[1], (int)src0->ne[2], (int)src0->ne[3] };
    const int dst_ne[4] = { (int)dst->ne[0],  (int)dst->ne[1],  (int)dst->ne[2],  (int)dst->ne[3]  };

    // Empty output: bail out before touching any kernel state.
    if (dst_ne[0] == 0 || dst_ne[1] == 0 || dst_ne[2] == 0 || dst_ne[3] == 0) {
        return;
    }

    const cl_ulong src_nb[4] = { (cl_ulong)src0->nb[0], (cl_ulong)src0->nb[1], (cl_ulong)src0->nb[2], (cl_ulong)src0->nb[3] };
    const cl_ulong dst_nb[4] = { (cl_ulong)dst->nb[0],  (cl_ulong)dst->nb[1],  (cl_ulong)dst->nb[2],  (cl_ulong)dst->nb[3]  };

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    // Byte offsets of each tensor's data within its device buffer.
    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;

    // Args 0-3: buffers and offsets.
    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));

    // Args 4-7: src extents, 8-11: src strides, 12-15: dst extents, 16-19: dst strides.
    for (cl_uint d = 0; d < 4; ++d) {
        CL_CHECK(clSetKernelArg(kernel,  4 + d, sizeof(int),      &src_ne[d]));
        CL_CHECK(clSetKernelArg(kernel,  8 + d, sizeof(cl_ulong), &src_nb[d]));
        CL_CHECK(clSetKernelArg(kernel, 12 + d, sizeof(int),      &dst_ne[d]));
        CL_CHECK(clSetKernelArg(kernel, 16 + d, sizeof(cl_ulong), &dst_nb[d]));
    }

    size_t global_work_size[3] = { (size_t)dst_ne[0], (size_t)dst_ne[1], (size_t)dst_ne[2] };

    // Work-group shape is 16x4x1, shrunk to the tensor extent on small
    // dimensions. That is at most 64 work-items, so no further clamping is needed.
    size_t local_work_size[3] = {
        global_work_size[0] < 16 ? global_work_size[0] : (size_t)16,
        global_work_size[1] <  4 ? global_work_size[1] : (size_t)4,
        1,
    };

    // Without non-uniform work-group support the global size must be a
    // multiple of the local size; otherwise pass NULL as the local size and
    // leave the choice to the OpenCL implementation.
    size_t * local_work_size_ptr = local_work_size;
    if (!backend_ctx->non_uniform_workgroups) {
        if (global_work_size[0] % local_work_size[0] != 0 ||
            global_work_size[1] % local_work_size[1] != 0 ||
            global_work_size[2] % local_work_size[2] != 0) {
            local_work_size_ptr = NULL;
        }
    }

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
6730+
65996731
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
66006732
GGML_ASSERT(src0);
66016733
GGML_ASSERT(src0->extra);
@@ -9775,6 +9907,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
97759907
}
97769908
func = ggml_cl_expm1;
97779909
break;
9910+
case GGML_UNARY_OP_SOFTPLUS:
9911+
if (!any_on_device) {
9912+
return false;
9913+
}
9914+
func = ggml_cl_softplus;
9915+
break;
97789916
default:
97799917
return false;
97809918
} break;
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
//------------------------------------------------------------------------------
4+
// softplus
5+
//------------------------------------------------------------------------------
6+
// softplus(x) = log(1 + e^x), evaluated as max(x, 0) + log1p(e^-|x|) so the
// exponential is only ever taken of a non-positive argument.
inline float softplus_f32(float x){
    const float positive_part = fmax(x, 0.0f);
    const float tail          = log1p(exp(-fabs(x)));
    return tail + positive_part;
}
11+
12+
// F32 softplus over a tensor of up to four dimensions addressed by byte strides.
// Each work-item owns one (dim0, dim1, dim2) position of the destination and
// walks dimension 3 itself. ne0x/nb0x describe the source, ne1x/nb1x the
// destination; the source extents are not read here.
kernel void kernel_softplus_f32_nd(
    global void * p_src0_base,
    ulong off_src0_abs,
    global void * p_dst_base,
    ulong off_dst_abs,
    int ne00,
    int ne01,
    int ne02,
    int ne03,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    ulong nb13
) {
    const int col   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items that fall outside the destination extent have nothing to do.
    if (col >= ne10 || row >= ne11 || plane >= ne12) {
        return;
    }

    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs;
    global char       * dst_bytes = (global char *)p_dst_base + off_dst_abs;

    // The contribution of the first three indices does not change across dim 3.
    const ulong src_fixed = (ulong)col*nb00 + (ulong)row*nb01 + (ulong)plane*nb02;
    const ulong dst_fixed = (ulong)col*nb10 + (ulong)row*nb11 + (ulong)plane*nb12;

    for (int batch = 0; batch < ne13; ++batch) {
        global const float * in  = (global const float *)(src_bytes + (src_fixed + (ulong)batch*nb03));
        global float       * out = (global float *)(dst_bytes + (dst_fixed + (ulong)batch*nb13));

        *out = softplus_f32(*in);
    }
}
50+
51+
// F16 softplus over a tensor of up to four dimensions addressed by byte strides.
// Each work-item owns one (dim0, dim1, dim2) position of the destination and
// walks dimension 3 itself. Values are widened to float for the computation
// and converted back to half on store. ne0x/nb0x describe the source,
// ne1x/nb1x the destination; the source extents are not read here.
kernel void kernel_softplus_f16_nd(
    global void * p_src0_base,
    ulong off_src0_abs,
    global void * p_dst_base,
    ulong off_dst_abs,
    int ne00,
    int ne01,
    int ne02,
    int ne03,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    ulong nb13
) {
    const int col   = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items that fall outside the destination extent have nothing to do.
    if (col >= ne10 || row >= ne11 || plane >= ne12) {
        return;
    }

    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs;
    global char       * dst_bytes = (global char *)p_dst_base + off_dst_abs;

    // The contribution of the first three indices does not change across dim 3.
    const ulong src_fixed = (ulong)col*nb00 + (ulong)row*nb01 + (ulong)plane*nb02;
    const ulong dst_fixed = (ulong)col*nb10 + (ulong)row*nb11 + (ulong)plane*nb12;

    for (int batch = 0; batch < ne13; ++batch) {
        global const half * in  = (global const half *)(src_bytes + (src_fixed + (ulong)batch*nb03));
        global half       * out = (global half *)(dst_bytes + (dst_fixed + (ulong)batch*nb13));

        const float x = (float)(*in);
        *out = (half)(softplus_f32(x));
    }
}

0 commit comments

Comments
 (0)