Skip to content

Commit a07dada

Browse files
committed
opencl: improve mul_mv_q8_0_f32_flat
1 parent fd78540 commit a07dada

File tree

1 file changed

+101
-25
lines changed

1 file changed

+101
-25
lines changed

ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -80,47 +80,123 @@ kernel void kernel_mul_mv_q8_0_f32_flat(
8080
global float * y = (global float *) (src1 + offset_src1);
8181

8282
// pointers to src0 rows
83-
global char * ax[N_R0_Q8_0];
84-
global half * ad[N_R0_Q8_0];
85-
for (int row = 0; row < N_R0_Q8_0; ++row) {
86-
ulong offset_src0 = ((first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/34;
87-
ax[row] = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
88-
ad[row] = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
89-
}
83+
uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
84+
85+
global char * ax0, * ax1, * ax2, * ax3;
86+
global half * ad0, * ad1, * ad2, * ad3;
87+
uint offset_src0;
88+
89+
offset_src0 = offset_src0_base + 0*nb01;
90+
offset_src0 = offset_src0/34;
91+
ax0 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
92+
ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
9093

91-
float yl[NB_Q8_0];
92-
float sumf[N_R0_Q8_0] = { 0.f };
94+
offset_src0 = offset_src0_base + 1*nb01;
95+
offset_src0 = offset_src0/34;
96+
ax1 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
97+
ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
98+
99+
offset_src0 = offset_src0_base + 2*nb01;
100+
offset_src0 = offset_src0/34;
101+
ax2 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
102+
ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
103+
104+
offset_src0 = offset_src0_base + 3*nb01;
105+
offset_src0 = offset_src0/34;
106+
ax3 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
107+
ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
93108

94109
const short ix = get_sub_group_local_id()/4;
95110
const short il = get_sub_group_local_id()%4;
96111

97112
global float * yb = y + ix*QK8_0 + il*NB_Q8_0;
98113

114+
float8 yl;
115+
float8 qv;
116+
float4 sumf = 0.f;
117+
float sumq = 0.f;
118+
global char * qs;
119+
99120
// each thread handles NB_Q8_0 quants at a time
100121
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {
101-
for (short i = 0; i < NB_Q8_0; ++i) {
102-
yl[i] = yb[i];
103-
}
104-
105-
for (short row = 0; row < N_R0_Q8_0; row++) {
106-
global char * qs = ax[row] + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
107-
float sumq = 0.f;
108-
for (short iq = 0; iq < NB_Q8_0; ++iq) {
109-
sumq += qs[iq] * yl[iq];
110-
}
111-
sumf[row] += sumq*ad[row][ib];
112-
}
122+
yl = vload8(0, yb);
123+
124+
qs = ax0 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
125+
qv = convert_float8(vload8(0, qs));
126+
sumq = 0;
127+
sumq += qv.s0*yl.s0;
128+
sumq += qv.s1*yl.s1;
129+
sumq += qv.s2*yl.s2;
130+
sumq += qv.s3*yl.s3;
131+
sumq += qv.s4*yl.s4;
132+
sumq += qv.s5*yl.s5;
133+
sumq += qv.s6*yl.s6;
134+
sumq += qv.s7*yl.s7;
135+
sumf.s0 += sumq*ad0[ib];
136+
137+
qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
138+
qv = convert_float8(vload8(0, qs));
139+
sumq = 0;
140+
sumq += qv.s0*yl.s0;
141+
sumq += qv.s1*yl.s1;
142+
sumq += qv.s2*yl.s2;
143+
sumq += qv.s3*yl.s3;
144+
sumq += qv.s4*yl.s4;
145+
sumq += qv.s5*yl.s5;
146+
sumq += qv.s6*yl.s6;
147+
sumq += qv.s7*yl.s7;
148+
sumf.s1 += sumq*ad1[ib];
149+
150+
qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
151+
qv = convert_float8(vload8(0, qs));
152+
sumq = 0;
153+
sumq += qv.s0*yl.s0;
154+
sumq += qv.s1*yl.s1;
155+
sumq += qv.s2*yl.s2;
156+
sumq += qv.s3*yl.s3;
157+
sumq += qv.s4*yl.s4;
158+
sumq += qv.s5*yl.s5;
159+
sumq += qv.s6*yl.s6;
160+
sumq += qv.s7*yl.s7;
161+
sumf.s2 += sumq*ad2[ib];
162+
163+
qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
164+
qv = convert_float8(vload8(0, qs));
165+
sumq = 0;
166+
sumq += qv.s0*yl.s0;
167+
sumq += qv.s1*yl.s1;
168+
sumq += qv.s2*yl.s2;
169+
sumq += qv.s3*yl.s3;
170+
sumq += qv.s4*yl.s4;
171+
sumq += qv.s5*yl.s5;
172+
sumq += qv.s6*yl.s6;
173+
sumq += qv.s7*yl.s7;
174+
sumf.s3 += sumq*ad3[ib];
113175

114176
yb += N_SIMDWIDTH*NB_Q8_0;
115177
}
116178

117179
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
118180

119-
for (int row = 0; row < N_R0_Q8_0; ++row) {
120-
float tot = sub_group_reduce_add(sumf[row]);
181+
float4 tot = (float4)(
182+
sub_group_reduce_add(sumf.s0),
183+
sub_group_reduce_add(sumf.s1),
184+
sub_group_reduce_add(sumf.s2),
185+
sub_group_reduce_add(sumf.s3)
186+
);
121187

122-
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
123-
dst_f32[first_row + row] = tot;
188+
if (get_sub_group_local_id() == 0) {
189+
if (first_row + 0 < ne01) {
190+
dst_f32[first_row + 0] = tot.s0;
191+
}
192+
if (first_row + 1 < ne01) {
193+
dst_f32[first_row + 1] = tot.s1;
194+
}
195+
if (first_row + 2 < ne01) {
196+
dst_f32[first_row + 2] = tot.s2;
197+
}
198+
if (first_row + 3 < ne01) {
199+
dst_f32[first_row + 3] = tot.s3;
124200
}
125201
}
126202
}

0 commit comments

Comments
 (0)