Skip to content

Commit d252563

Browse files
committed
opencl: broadcast for soft_max
1 parent e75ba4c commit d252563

File tree

5 files changed

+138
-49
lines changed

5 files changed

+138
-49
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5757,19 +5757,32 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
57575757

57585758
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
57595759

5760-
const int ne00 = src0 ? src0->ne[0] : 0;
5761-
const int ne01 = src0 ? src0->ne[1] : 0;
5762-
const int ne02 = src0 ? src0->ne[2] : 0;
5763-
const int ne03 = src0 ? src0->ne[3] : 0;
5760+
const int ne00 = src0->ne[0];
5761+
const int ne01 = src0->ne[1];
5762+
const int ne02 = src0->ne[2];
5763+
const int ne03 = src0->ne[3];
5764+
5765+
const cl_long nb01 = src0->nb[1];
5766+
const cl_long nb02 = src0->nb[2];
5767+
const cl_long nb03 = src0->nb[3];
5768+
5769+
const int ne11 = src1 ? src1->ne[1] : 0;
5770+
const int ne12 = src1 ? src1->ne[2] : 0;
5771+
const int ne13 = src1 ? src1->ne[3] : 0;
5772+
5773+
const cl_long nb11 = src1 ? src1->nb[1] : 0;
5774+
const cl_long nb12 = src1 ? src1->nb[2] : 0;
5775+
const cl_long nb13 = src1 ? src1->nb[3] : 0;
5776+
5777+
const cl_long nb1 = dst->nb[1];
5778+
const cl_long nb2 = dst->nb[2];
5779+
const cl_long nb3 = dst->nb[3];
57645780

57655781
float scale, max_bias;
57665782
memcpy(&scale, dst->op_params + 0, sizeof(float));
57675783
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
57685784

5769-
const int nrows_x = ggml_nrows(src0);
5770-
const int nrows_y = src0->ne[1];
5771-
5772-
const int n_head = nrows_x/nrows_y;
5785+
const int n_head = src0->ne[2];
57735786
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
57745787

57755788
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5816,11 +5829,23 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
58165829
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
58175830
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
58185831
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5819-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
5820-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
5821-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
5822-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
5823-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
5832+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
5833+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
5834+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
5835+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
5836+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
5837+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne13));
5838+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
5839+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
5840+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
5841+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb1));
5842+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb2));
5843+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb3));
5844+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &scale));
5845+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &max_bias));
5846+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m0));
5847+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float), &m1));
5848+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &n_head_log2));
58245849

58255850
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
58265851
size_t local_work_size[] = {(size_t)nth, 1, 1};

ggml/src/ggml-opencl/kernels/softmax_4_f16.cl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,48 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_4_f16(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global half * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
3232
int ne01,
3333
int ne02,
34+
ulong nb01,
35+
ulong nb02,
36+
ulong nb03,
37+
int ne11,
38+
int ne12,
39+
int ne13,
40+
ulong nb11,
41+
ulong nb12,
42+
ulong nb13,
43+
ulong nb1,
44+
ulong nb2,
45+
ulong nb3,
3446
float scale,
3547
float max_bias,
3648
float m0,
3749
float m1,
3850
int n_head_log2
3951
) {
40-
src0 = (global float *)((global char *)src0 + offset0);
41-
src1 = (global half *)((global char *)src1 + offset1);
42-
dst = (global float *)((global char *)dst + offsetd);
52+
src0 = src0 + offset0;
53+
src1 = src1 + offset1;
54+
dst = dst + offsetd;
4355

4456
int i03 = get_group_id(2);
4557
int i02 = get_group_id(1);
4658
int i01 = get_group_id(0);
4759

48-
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49-
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
50-
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
60+
int i13 = i03%ne13;
61+
int i12 = i02%ne12;
62+
int i11 = i01;
63+
64+
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65+
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5167

5268
float slope = 1.0f;
5369

ggml/src/ggml-opencl/kernels/softmax_4_f32.cl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,48 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_4(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global float * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
3232
int ne01,
3333
int ne02,
34+
ulong nb01,
35+
ulong nb02,
36+
ulong nb03,
37+
int ne11,
38+
int ne12,
39+
int ne13,
40+
ulong nb11,
41+
ulong nb12,
42+
ulong nb13,
43+
ulong nb1,
44+
ulong nb2,
45+
ulong nb3,
3446
float scale,
3547
float max_bias,
3648
float m0,
3749
float m1,
3850
int n_head_log2
3951
) {
40-
src0 = (global float*)((global char*)src0 + offset0);
41-
src1 = (global float*)((global char*)src1 + offset1);
42-
dst = (global float*)((global char*)dst + offsetd);
52+
src0 = src0 + offset0;
53+
src1 = src1 + offset1;
54+
dst = dst + offsetd;
4355

4456
int i03 = get_group_id(2);
4557
int i02 = get_group_id(1);
4658
int i01 = get_group_id(0);
4759

48-
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49-
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
50-
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
60+
int i13 = i03%ne13;
61+
int i12 = i02%ne12;
62+
int i11 = i01;
63+
64+
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65+
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5167

5268
float slope = 1.0f;
5369

ggml/src/ggml-opencl/kernels/softmax_f16.cl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,48 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_f16(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global half * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
3232
int ne01,
3333
int ne02,
34+
ulong nb01,
35+
ulong nb02,
36+
ulong nb03,
37+
int ne11,
38+
int ne12,
39+
int ne13,
40+
ulong nb11,
41+
ulong nb12,
42+
ulong nb13,
43+
ulong nb1,
44+
ulong nb2,
45+
ulong nb3,
3446
float scale,
3547
float max_bias,
3648
float m0,
3749
float m1,
3850
int n_head_log2
3951
) {
40-
src0 = (global float *)((global char *)src0 + offset0);
41-
src1 = (global half *)((global char *)src1 + offset1);
42-
dst = (global float *)((global char *)dst + offsetd);
52+
src0 = src0 + offset0;
53+
src1 = src1 + offset1;
54+
dst = dst + offsetd;
4355

4456
int i03 = get_group_id(2);
4557
int i02 = get_group_id(1);
4658
int i01 = get_group_id(0);
4759

48-
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49-
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
50-
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
60+
int i13 = i03%ne13;
61+
int i12 = i02%ne12;
62+
int i11 = i01;
63+
64+
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65+
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5167

5268
float slope = 1.0f;
5369

ggml/src/ggml-opencl/kernels/softmax_f32.cl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,48 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global float * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
3232
int ne01,
3333
int ne02,
34+
ulong nb01,
35+
ulong nb02,
36+
ulong nb03,
37+
int ne11,
38+
int ne12,
39+
int ne13,
40+
ulong nb11,
41+
ulong nb12,
42+
ulong nb13,
43+
ulong nb1,
44+
ulong nb2,
45+
ulong nb3,
3446
float scale,
3547
float max_bias,
3648
float m0,
3749
float m1,
3850
int n_head_log2
3951
) {
40-
src0 = (global float*)((global char*)src0 + offset0);
41-
src1 = (global float*)((global char*)src1 + offset1);
42-
dst = (global float*)((global char*)dst + offsetd);
52+
src0 = src0 + offset0;
53+
src1 = src1 + offset1;
54+
dst = dst + offsetd;
4355

4456
int i03 = get_group_id(2);
4557
int i02 = get_group_id(1);
4658
int i01 = get_group_id(0);
4759

48-
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49-
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
50-
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
60+
int i13 = i03%ne13;
61+
int i12 = i02%ne12;
62+
int i11 = i01;
63+
64+
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
65+
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5167

5268
float slope = 1.0f;
5369

0 commit comments

Comments
 (0)