Skip to content

Commit 2e87a69

Browse files
committed
opencl: support ne3 in get_rows
1 parent d413dca commit 2e87a69

File tree

2 files changed

+60
-29
lines changed

2 files changed

+60
-29
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -3780,15 +3780,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
37803780
GGML_ASSERT(dst);
37813781
GGML_ASSERT(dst->extra);
37823782

3783-
const int ne00 = src0 ? src0->ne[0] : 0;
3784-
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
3785-
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
3786-
const int ne10 = src1 ? src1->ne[0] : 0;
3787-
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
3788-
const int ne11 = src1 ? src1->ne[1] : 0;
3789-
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
3790-
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
3791-
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
3783+
const int ne00 = src0->ne[0];
3784+
const cl_ulong nb01 = src0->nb[1];
3785+
const cl_ulong nb02 = src0->nb[2];
3786+
const cl_ulong nb03 = src0->nb[3];
3787+
const int ne10 = src1->ne[0];
3788+
const cl_ulong nb10 = src1->nb[0];
3789+
const int ne11 = src1->ne[1];
3790+
const int ne12 = src1->ne[2];
3791+
const cl_ulong nb11 = src1->nb[1];
3792+
const cl_ulong nb12 = src1->nb[2];
3793+
const cl_ulong nb1 = dst->nb[1];
3794+
const cl_ulong nb2 = dst->nb[2];
3795+
const cl_ulong nb3 = dst->nb[3];
37923796

37933797
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
37943798

@@ -3825,14 +3829,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
38253829
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
38263830
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
38273831
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
3828-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
3829-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
3830-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
3831-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
3832-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
3833-
3834-
size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
3835-
size_t local_work_size[] = {1, 1, 1};
3832+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
3833+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
3834+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
3835+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
3836+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
3837+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
3838+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
3839+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
3840+
3841+
size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
3842+
size_t local_work_size[] = {64, 1, 1};
38363843

38373844
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
38383845
}

ggml/src/ggml-opencl/kernels/get_rows.cl

Lines changed: 36 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -69,26 +69,34 @@ kernel void kernel_get_rows_f32(
6969
int ne00,
7070
ulong nb01,
7171
ulong nb02,
72+
ulong nb03,
7273
int ne10,
7374
ulong nb10,
7475
ulong nb11,
76+
ulong nb12,
7577
ulong nb1,
76-
ulong nb2
78+
ulong nb2,
79+
ulong nb3
7780
) {
7881
src0 = (global void*)((global char*)src0 + offset0);
7982
src1 = (global int*)((global char*)src1 + offset1);
8083
dst = (global float*)((global char*)dst + offsetd);
8184

8285
int i10 = get_group_id(0);
8386
int i11 = get_group_id(1);
87+
int i12 = get_group_id(2);
8488

85-
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
89+
int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
8690

8791
int i02 = i11;
92+
int i03 = i12;
8893

8994
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
90-
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
91-
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
95+
if (ind >= ne00) {
96+
return;
97+
}
98+
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
99+
((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
92100
}
93101
}
94102

@@ -102,26 +110,34 @@ kernel void kernel_get_rows_f16(
102110
int ne00,
103111
ulong nb01,
104112
ulong nb02,
113+
ulong nb03,
105114
int ne10,
106115
ulong nb10,
107116
ulong nb11,
117+
ulong nb12,
108118
ulong nb1,
109-
ulong nb2
119+
ulong nb2,
120+
ulong nb3
110121
) {
111122
src0 = (global void*)((global char*)src0 + offset0);
112123
src1 = (global int*)((global char*)src1 + offset1);
113124
dst = (global float*)((global char*)dst + offsetd);
114125

115126
int i10 = get_group_id(0);
116127
int i11 = get_group_id(1);
128+
int i12 = get_group_id(2);
117129

118-
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
130+
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
119131

120132
int i02 = i11;
133+
int i03 = i12;
121134

122135
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
123-
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
124-
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
136+
if (ind >= ne00) {
137+
return;
138+
}
139+
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
140+
((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
125141
}
126142
}
127143

@@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
135151
int ne00,
136152
ulong nb01,
137153
ulong nb02,
154+
ulong nb03,
138155
int ne10,
139156
ulong nb10,
140157
ulong nb11,
158+
ulong nb12,
141159
ulong nb1,
142-
ulong nb2
160+
ulong nb2,
161+
ulong nb3
143162
) {
144163
src0 = (global void*)((global char*)src0 + offset0);
145164
src1 = (global int*)((global char*)src1 + offset1);
@@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(
149168

150169
int i10 = get_group_id(0);
151170
int i11 = get_group_id(1);
171+
int i12 = get_group_id(2);
152172

153-
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
173+
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
154174

155175
int i02 = i11;
176+
int i03 = i12;
156177

157178
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
158179
float16 temp;
180+
if (ind >= ne00) {
181+
return;
182+
}
159183
dequantize_q4_0_f32(
160-
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
161-
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
184+
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
185+
*(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
162186
}
163187
}

0 commit comments

Comments
 (0)