Skip to content

Commit de864ca

Browse files
committed
opencl: improve mul_mv_id_q8_0_f32_flat
1 parent 2f89151 commit de864ca

File tree

1 file changed

+105
-28
lines changed

1 file changed

+105
-28
lines changed

ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl

Lines changed: 105 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
6868
src2 = (global char *)((global char *)src2 + offset2);
6969
dst = (global char *)((global char *)dst + offsetd);
7070

71-
int iid1 = get_group_id(2)/ne20;
72-
int idx = get_group_id(2)%ne20;
71+
int iid1 = (int)get_group_id(2)/ne20;
72+
int idx = (int)get_group_id(2)%ne20;
7373

7474
int i02 = ((global int *) (src2 + iid1*nb21))[idx];
7575

@@ -80,7 +80,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
8080
int i2 = i12_;
8181

8282
// 34 == sizeof(block_q8_0)
83-
ulong src0_off = i02*nb02/34;
83+
uint src0_off = i02*nb02;
84+
src0_off /= 34;
8485

8586
global char * src0_q_cur = src0_q + src0_off*sizeof(char)*QK8_0;
8687
global half * src0_d_cur = src0_d + src0_off;
@@ -99,47 +100,123 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
99100
global float * y = (global float *) (src1_cur + offset_src1);
100101

101102
// pointers to src0 rows
102-
global char * ax[N_R0_Q8_0];
103-
global half * ad[N_R0_Q8_0];
104-
for (int row = 0; row < N_R0_Q8_0; ++row) {
105-
ulong offset_src0 = (first_row + row)*nb01/34;
106-
ax[row] = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);
107-
ad[row] = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));
108-
}
103+
uint offset_src0_base = first_row*nb01;
104+
105+
global char * ax0, * ax1, * ax2, * ax3;
106+
global half * ad0, * ad1, * ad2, * ad3;
107+
uint offset_src0;
108+
109+
offset_src0 = offset_src0_base + 0*nb01;
110+
offset_src0 = offset_src0/34;
111+
ax0 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);
112+
ad0 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));
109113

110-
float yl[NB_Q8_0];
111-
float sumf[N_R0_Q8_0] = { 0.f };
114+
offset_src0 = offset_src0_base + 1*nb01;
115+
offset_src0 = offset_src0/34;
116+
ax1 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);
117+
ad1 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));
118+
119+
offset_src0 = offset_src0_base + 2*nb01;
120+
offset_src0 = offset_src0/34;
121+
ax2 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);
122+
ad2 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));
123+
124+
offset_src0 = offset_src0_base + 3*nb01;
125+
offset_src0 = offset_src0/34;
126+
ax3 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);
127+
ad3 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));
112128

113129
const short ix = get_sub_group_local_id()/4;
114130
const short il = get_sub_group_local_id()%4;
115131

116132
global float * yb = y + ix*QK8_0 + il*NB_Q8_0;
117133

134+
float8 yl;
135+
float8 qv;
136+
float4 sumf = 0.f;
137+
float sumq = 0.f;
138+
global char * qs;
139+
118140
// each thread handles NB_Q8_0 quants at a time
119141
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {
120-
for (short i = 0; i < NB_Q8_0; ++i) {
121-
yl[i] = yb[i];
122-
}
123-
124-
for (short row = 0; row < N_R0_Q8_0; row++) {
125-
global char * qs = ax[row] + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
126-
float sumq = 0.f;
127-
for (short iq = 0; iq < NB_Q8_0; ++iq) {
128-
sumq += qs[iq] * yl[iq];
129-
}
130-
sumf[row] += sumq*ad[row][ib];
131-
}
142+
yl = vload8(0, yb);
143+
144+
qs = ax0 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
145+
qv = convert_float8(vload8(0, qs));
146+
sumq = 0;
147+
sumq += qv.s0*yl.s0;
148+
sumq += qv.s1*yl.s1;
149+
sumq += qv.s2*yl.s2;
150+
sumq += qv.s3*yl.s3;
151+
sumq += qv.s4*yl.s4;
152+
sumq += qv.s5*yl.s5;
153+
sumq += qv.s6*yl.s6;
154+
sumq += qv.s7*yl.s7;
155+
sumf.s0 += sumq*ad0[ib];
156+
157+
qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
158+
qv = convert_float8(vload8(0, qs));
159+
sumq = 0;
160+
sumq += qv.s0*yl.s0;
161+
sumq += qv.s1*yl.s1;
162+
sumq += qv.s2*yl.s2;
163+
sumq += qv.s3*yl.s3;
164+
sumq += qv.s4*yl.s4;
165+
sumq += qv.s5*yl.s5;
166+
sumq += qv.s6*yl.s6;
167+
sumq += qv.s7*yl.s7;
168+
sumf.s1 += sumq*ad1[ib];
169+
170+
qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
171+
qv = convert_float8(vload8(0, qs));
172+
sumq = 0;
173+
sumq += qv.s0*yl.s0;
174+
sumq += qv.s1*yl.s1;
175+
sumq += qv.s2*yl.s2;
176+
sumq += qv.s3*yl.s3;
177+
sumq += qv.s4*yl.s4;
178+
sumq += qv.s5*yl.s5;
179+
sumq += qv.s6*yl.s6;
180+
sumq += qv.s7*yl.s7;
181+
sumf.s2 += sumq*ad2[ib];
182+
183+
qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
184+
qv = convert_float8(vload8(0, qs));
185+
sumq = 0;
186+
sumq += qv.s0*yl.s0;
187+
sumq += qv.s1*yl.s1;
188+
sumq += qv.s2*yl.s2;
189+
sumq += qv.s3*yl.s3;
190+
sumq += qv.s4*yl.s4;
191+
sumq += qv.s5*yl.s5;
192+
sumq += qv.s6*yl.s6;
193+
sumq += qv.s7*yl.s7;
194+
sumf.s3 += sumq*ad3[ib];
132195

133196
yb += N_SIMDWIDTH*NB_Q8_0;
134197
}
135198

136199
global float * dst_f32 = (global float *) dst_cur + (ulong)r1*ne0;
137200

138-
for (int row = 0; row < N_R0_Q8_0; ++row) {
139-
float tot = sub_group_reduce_add(sumf[row]);
201+
float4 tot = (float4)(
202+
sub_group_reduce_add(sumf.s0),
203+
sub_group_reduce_add(sumf.s1),
204+
sub_group_reduce_add(sumf.s2),
205+
sub_group_reduce_add(sumf.s3)
206+
);
140207

141-
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
142-
dst_f32[first_row + row] = tot;
208+
if (get_sub_group_local_id() == 0) {
209+
if (first_row + 0 < ne01) {
210+
dst_f32[first_row + 0] = tot.s0;
211+
}
212+
if (first_row + 1 < ne01) {
213+
dst_f32[first_row + 1] = tot.s1;
214+
}
215+
if (first_row + 2 < ne01) {
216+
dst_f32[first_row + 2] = tot.s2;
217+
}
218+
if (first_row + 3 < ne01) {
219+
dst_f32[first_row + 3] = tot.s3;
143220
}
144221
}
145222
}

0 commit comments

Comments
 (0)