Skip to content

Commit 60a7658

Browse files
authored
opencl: allow mixed f16/f32 add (#15140)
1 parent efe3a90 commit 60a7658

File tree

2 files changed

+162
-79
lines changed

2 files changed

+162
-79
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 120 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24812481
case GGML_OP_SCALE:
24822482
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
24832483
case GGML_OP_ADD:
2484+
if (op->type == GGML_TYPE_F16) {
2485+
const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
2486+
const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
2487+
if (src0_ok && src1_ok) {
2488+
return true;
2489+
}
2490+
}
24842491
case GGML_OP_MUL:
24852492
case GGML_OP_DIV:
24862493
case GGML_OP_SUB:
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37173724
GGML_ASSERT(dst);
37183725
GGML_ASSERT(dst->extra);
37193726

3720-
GGML_ASSERT(src0->type == src1->type);
3721-
GGML_ASSERT(src0->type == dst->type);
3722-
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3723-
3724-
const int ne00 = src0->ne[0];
3725-
const int ne01 = src0->ne[1];
3726-
const int ne02 = src0->ne[2];
3727-
const int ne03 = src0->ne[3];
3727+
const int ne00 = src0->ne[0];
3728+
const int ne01 = src0->ne[1];
3729+
const int ne02 = src0->ne[2];
3730+
const int ne03 = src0->ne[3];
37283731

37293732
const cl_ulong nb00 = src0->nb[0];
37303733
const cl_ulong nb01 = src0->nb[1];
37313734
const cl_ulong nb02 = src0->nb[2];
37323735
const cl_ulong nb03 = src0->nb[3];
37333736

3734-
const int ne10 = src1->ne[0];
3735-
const int ne11 = src1->ne[1];
3736-
const int ne12 = src1->ne[2];
3737-
const int ne13 = src1->ne[3]; UNUSED(ne13);
3737+
const int ne10 = src1->ne[0];
3738+
const int ne11 = src1->ne[1];
3739+
const int ne12 = src1->ne[2];
3740+
const int ne13 = src1->ne[3];
37383741

37393742
const cl_ulong nb10 = src1->nb[0];
37403743
const cl_ulong nb11 = src1->nb[1];
37413744
const cl_ulong nb12 = src1->nb[2];
3742-
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
3745+
const cl_ulong nb13 = src1->nb[3];
37433746

3744-
const int ne0 = dst->ne[0];
3745-
const int ne1 = dst->ne[1];
3746-
const int ne2 = dst->ne[2];
3747-
const int ne3 = dst->ne[3];
3747+
const int ne0 = dst->ne[0];
3748+
const int ne1 = dst->ne[1];
3749+
const int ne2 = dst->ne[2];
3750+
const int ne3 = dst->ne[3];
37483751

37493752
const cl_ulong nb0 = dst->nb[0];
37503753
const cl_ulong nb1 = dst->nb[1];
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37613764
cl_ulong offset1 = extra1->offset + src1->view_offs;
37623765
cl_ulong offsetd = extrad->offset + dst->view_offs;
37633766

3764-
bool bcast_row = false;
37653767
cl_kernel kernel;
37663768

3767-
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
3768-
GGML_ASSERT(ggml_is_contiguous(src0));
3769+
const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
37693770

3770-
// src1 is a row
3771+
if (bcast_row) {
3772+
GGML_ASSERT(ggml_is_contiguous(src0));
37713773
GGML_ASSERT(ne11 == 1);
3774+
}
37723775

3773-
bcast_row = true;
3774-
int ne = ne00 / 4;
3775-
3776-
if (src0->type == GGML_TYPE_F32) {
3776+
if (dst->type == GGML_TYPE_F32) {
3777+
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
3778+
if (bcast_row) {
37773779
kernel = backend_ctx->kernel_add_row;
3780+
const int ne = ne00 / 4;
3781+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3782+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3783+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3784+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3785+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3786+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3787+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
37783788
} else {
3779-
kernel = backend_ctx->kernel_add_row_f16;
3780-
}
3781-
3782-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3783-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3784-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3785-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3786-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3787-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3788-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
3789-
} else {
3790-
if (src0->type == GGML_TYPE_F32) {
37913789
kernel = backend_ctx->kernel_add;
3790+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3791+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3792+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3793+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3794+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3795+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3796+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
3797+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
3798+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
3799+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
3800+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
3801+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
3802+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
3803+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
3804+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
3805+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
3806+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
3807+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
3808+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
3809+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
3810+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
3811+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
3812+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
3813+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
3814+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
3815+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
3816+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
3817+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
3818+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
3819+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
3820+
}
3821+
} else if (dst->type == GGML_TYPE_F16) {
3822+
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3823+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
3824+
const int type_src0 = (src0->type == GGML_TYPE_F32);
3825+
const int type_src1 = (src1->type == GGML_TYPE_F32);
3826+
if (bcast_row) {
3827+
kernel = backend_ctx->kernel_add_row_f16;
3828+
const int ne = ne00 / 4;
3829+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3830+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3831+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3832+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3833+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3834+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3835+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
3836+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
3837+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
37923838
} else {
37933839
kernel = backend_ctx->kernel_add_f16;
3840+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3841+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3842+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3843+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3844+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3845+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3846+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
3847+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
3848+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
3849+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
3850+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
3851+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
3852+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
3853+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
3854+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
3855+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
3856+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
3857+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
3858+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
3859+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
3860+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
3861+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
3862+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
3863+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
3864+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
3865+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
3866+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
3867+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
3868+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
3869+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
3870+
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
3871+
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
37943872
}
3795-
3796-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3797-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3798-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3799-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3800-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3801-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3802-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
3803-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
3804-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
3805-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
3806-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
3807-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
3808-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
3809-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
3810-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
3811-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
3812-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
3813-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
3814-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
3815-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
3816-
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
3817-
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
3818-
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
3819-
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
3820-
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
3821-
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
3822-
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
3823-
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
3824-
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
3825-
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
3873+
} else {
3874+
GGML_ASSERT(false && "unsupported data types for add");
38263875
}
38273876

38283877
if (bcast_row) {
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
38323881

38333882
size_t * local_work_size_ptr = local_work_size;
38343883
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
3835-
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
3884+
local_work_size_ptr = nullptr;
38363885
}
38373886

3838-
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
3887+
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
38393888
} else {
38403889
unsigned int nth = MIN(64, ne0);
3841-
size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
3890+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
38423891
size_t local_work_size[] = {nth, 1, 1};
38433892

38443893
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);

ggml/src/ggml-opencl/kernels/add.cl

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ kernel void kernel_add_f16(
112112
ulong nb0,
113113
ulong nb1,
114114
ulong nb2,
115-
ulong nb3
115+
ulong nb3,
116+
int type_src0,
117+
int type_src1
116118
) {
117119
src0 = src0 + offset0;
118120
src1 = src1 + offset1;
@@ -132,25 +134,57 @@ kernel void kernel_add_f16(
132134

133135
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
134136
const int i10 = i0 % ne10;
135-
*((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));
137+
138+
half v0, v1;
139+
if (type_src0 == 1) {
140+
v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
141+
} else {
142+
v0 = *((global half *)(src0_ptr + i0*nb00));
143+
}
144+
145+
if (type_src1 == 1) {
146+
v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
147+
} else {
148+
v1 = *((global half *)(src1_ptr + i10*nb10));
149+
}
150+
151+
*((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
136152
}
137153
}
138154

139155
kernel void kernel_add_row_f16(
140-
global half4 * src0,
156+
global char * src0,
141157
ulong offset0,
142-
global half4 * src1,
158+
global char * src1,
143159
ulong offset1,
144160
global half4 * dst,
145161
ulong offsetd,
146-
int ne
162+
int ne,
163+
int type_src0,
164+
int type_src1
147165
) {
148-
src0 = (global half4*)((global char*)src0 + offset0);
149-
src1 = (global half4*)((global char*)src1 + offset1);
150166
dst = (global half4*)((global char*)dst + offsetd);
151167

152168
// This performs better than using %.
153169
uint gid = get_global_id(0);
154170
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
155-
dst[gid] = src0[gid] + src1[idx1];
171+
172+
half4 v0, v1;
173+
if (type_src0 == 1) {
174+
global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
175+
v0 = convert_half4(src0_f32[gid]);
176+
} else {
177+
global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
178+
v0 = src0_f16[gid];
179+
}
180+
181+
if (type_src1 == 1) {
182+
global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
183+
v1 = convert_half4(src1_f32[idx1]);
184+
} else {
185+
global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
186+
v1 = src1_f16[idx1];
187+
}
188+
189+
dst[gid] = v0 + v1;
156190
}

0 commit comments

Comments
 (0)