@@ -68,8 +68,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
68
68
src2 = (global char * )((global char * )src2 + offset2 );
69
69
dst = (global char * )((global char * )dst + offsetd );
70
70
71
- int iid1 = get_group_id (2 )/ne20 ;
72
- int idx = get_group_id (2 )%ne20 ;
71
+ int iid1 = ( int ) get_group_id (2 )/ne20 ;
72
+ int idx = ( int ) get_group_id (2 )%ne20 ;
73
73
74
74
int i02 = ((global int * ) (src2 + iid1 * nb21 ))[idx ];
75
75
@@ -80,7 +80,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
80
80
int i2 = i12_ ;
81
81
82
82
// 34 == sizeof(block_q8_0)
83
- ulong src0_off = i02 * nb02 /34 ;
83
+ uint src0_off = i02 * nb02 ;
84
+ src0_off /= 34 ;
84
85
85
86
global char * src0_q_cur = src0_q + src0_off * sizeof (char )* QK8_0 ;
86
87
global half * src0_d_cur = src0_d + src0_off ;
@@ -99,47 +100,123 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
99
100
global float * y = (global float * ) (src1_cur + offset_src1 );
100
101
101
102
// pointers to src0 rows
102
- global char * ax [N_R0_Q8_0 ];
103
- global half * ad [N_R0_Q8_0 ];
104
- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
105
- ulong offset_src0 = (first_row + row )* nb01 /34 ;
106
- ax [row ] = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
107
- ad [row ] = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
108
- }
103
+ uint offset_src0_base = first_row * nb01 ;
104
+
105
+ global char * ax0 , * ax1 , * ax2 , * ax3 ;
106
+ global half * ad0 , * ad1 , * ad2 , * ad3 ;
107
+ uint offset_src0 ;
108
+
109
+ offset_src0 = offset_src0_base + 0 * nb01 ;
110
+ offset_src0 = offset_src0 /34 ;
111
+ ax0 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
112
+ ad0 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
109
113
110
- float yl [NB_Q8_0 ];
111
- float sumf [N_R0_Q8_0 ] = { 0.f };
114
+ offset_src0 = offset_src0_base + 1 * nb01 ;
115
+ offset_src0 = offset_src0 /34 ;
116
+ ax1 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
117
+ ad1 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
118
+
119
+ offset_src0 = offset_src0_base + 2 * nb01 ;
120
+ offset_src0 = offset_src0 /34 ;
121
+ ax2 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
122
+ ad2 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
123
+
124
+ offset_src0 = offset_src0_base + 3 * nb01 ;
125
+ offset_src0 = offset_src0 /34 ;
126
+ ax3 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
127
+ ad3 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
112
128
113
129
const short ix = get_sub_group_local_id ()/4 ;
114
130
const short il = get_sub_group_local_id ()%4 ;
115
131
116
132
global float * yb = y + ix * QK8_0 + il * NB_Q8_0 ;
117
133
134
+ float8 yl ;
135
+ float8 qv ;
136
+ float4 sumf = 0.f ;
137
+ float sumq = 0.f ;
138
+ global char * qs ;
139
+
118
140
// each thread handles NB_Q8_0 quants at a time
119
141
for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /4 ) {
120
- for (short i = 0 ; i < NB_Q8_0 ; ++ i ) {
121
- yl [i ] = yb [i ];
122
- }
123
-
124
- for (short row = 0 ; row < N_R0_Q8_0 ; row ++ ) {
125
- global char * qs = ax [row ] + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
126
- float sumq = 0.f ;
127
- for (short iq = 0 ; iq < NB_Q8_0 ; ++ iq ) {
128
- sumq += qs [iq ] * yl [iq ];
129
- }
130
- sumf [row ] += sumq * ad [row ][ib ];
131
- }
142
+ yl = vload8 (0 , yb );
143
+
144
+ qs = ax0 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
145
+ qv = convert_float8 (vload8 (0 , qs ));
146
+ sumq = 0 ;
147
+ sumq += qv .s0 * yl .s0 ;
148
+ sumq += qv .s1 * yl .s1 ;
149
+ sumq += qv .s2 * yl .s2 ;
150
+ sumq += qv .s3 * yl .s3 ;
151
+ sumq += qv .s4 * yl .s4 ;
152
+ sumq += qv .s5 * yl .s5 ;
153
+ sumq += qv .s6 * yl .s6 ;
154
+ sumq += qv .s7 * yl .s7 ;
155
+ sumf .s0 += sumq * ad0 [ib ];
156
+
157
+ qs = ax1 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
158
+ qv = convert_float8 (vload8 (0 , qs ));
159
+ sumq = 0 ;
160
+ sumq += qv .s0 * yl .s0 ;
161
+ sumq += qv .s1 * yl .s1 ;
162
+ sumq += qv .s2 * yl .s2 ;
163
+ sumq += qv .s3 * yl .s3 ;
164
+ sumq += qv .s4 * yl .s4 ;
165
+ sumq += qv .s5 * yl .s5 ;
166
+ sumq += qv .s6 * yl .s6 ;
167
+ sumq += qv .s7 * yl .s7 ;
168
+ sumf .s1 += sumq * ad1 [ib ];
169
+
170
+ qs = ax2 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
171
+ qv = convert_float8 (vload8 (0 , qs ));
172
+ sumq = 0 ;
173
+ sumq += qv .s0 * yl .s0 ;
174
+ sumq += qv .s1 * yl .s1 ;
175
+ sumq += qv .s2 * yl .s2 ;
176
+ sumq += qv .s3 * yl .s3 ;
177
+ sumq += qv .s4 * yl .s4 ;
178
+ sumq += qv .s5 * yl .s5 ;
179
+ sumq += qv .s6 * yl .s6 ;
180
+ sumq += qv .s7 * yl .s7 ;
181
+ sumf .s2 += sumq * ad2 [ib ];
182
+
183
+ qs = ax3 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
184
+ qv = convert_float8 (vload8 (0 , qs ));
185
+ sumq = 0 ;
186
+ sumq += qv .s0 * yl .s0 ;
187
+ sumq += qv .s1 * yl .s1 ;
188
+ sumq += qv .s2 * yl .s2 ;
189
+ sumq += qv .s3 * yl .s3 ;
190
+ sumq += qv .s4 * yl .s4 ;
191
+ sumq += qv .s5 * yl .s5 ;
192
+ sumq += qv .s6 * yl .s6 ;
193
+ sumq += qv .s7 * yl .s7 ;
194
+ sumf .s3 += sumq * ad3 [ib ];
132
195
133
196
yb += N_SIMDWIDTH * NB_Q8_0 ;
134
197
}
135
198
136
199
global float * dst_f32 = (global float * ) dst_cur + (ulong )r1 * ne0 ;
137
200
138
- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
139
- float tot = sub_group_reduce_add (sumf [row ]);
201
+ float4 tot = (float4 )(
202
+ sub_group_reduce_add (sumf .s0 ),
203
+ sub_group_reduce_add (sumf .s1 ),
204
+ sub_group_reduce_add (sumf .s2 ),
205
+ sub_group_reduce_add (sumf .s3 )
206
+ );
140
207
141
- if (get_sub_group_local_id () == 0 && first_row + row < ne01 ) {
142
- dst_f32 [first_row + row ] = tot ;
208
+ if (get_sub_group_local_id () == 0 ) {
209
+ if (first_row + 0 < ne01 ) {
210
+ dst_f32 [first_row + 0 ] = tot .s0 ;
211
+ }
212
+ if (first_row + 1 < ne01 ) {
213
+ dst_f32 [first_row + 1 ] = tot .s1 ;
214
+ }
215
+ if (first_row + 2 < ne01 ) {
216
+ dst_f32 [first_row + 2 ] = tot .s2 ;
217
+ }
218
+ if (first_row + 3 < ne01 ) {
219
+ dst_f32 [first_row + 3 ] = tot .s3 ;
143
220
}
144
221
}
145
222
}
0 commit comments