@@ -68,8 +68,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
6868 src2 = (global char * )((global char * )src2 + offset2 );
6969 dst = (global char * )((global char * )dst + offsetd );
7070
71- int iid1 = get_group_id (2 )/ne20 ;
72- int idx = get_group_id (2 )%ne20 ;
71+ int iid1 = ( int ) get_group_id (2 )/ne20 ;
72+ int idx = ( int ) get_group_id (2 )%ne20 ;
7373
7474 int i02 = ((global int * ) (src2 + iid1 * nb21 ))[idx ];
7575
@@ -80,7 +80,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
8080 int i2 = i12_ ;
8181
8282 // 34 == sizeof(block_q8_0)
83- ulong src0_off = i02 * nb02 /34 ;
83+ uint src0_off = i02 * nb02 ;
84+ src0_off /= 34 ;
8485
8586 global char * src0_q_cur = src0_q + src0_off * sizeof (char )* QK8_0 ;
8687 global half * src0_d_cur = src0_d + src0_off ;
@@ -99,47 +100,123 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
99100 global float * y = (global float * ) (src1_cur + offset_src1 );
100101
101102 // pointers to src0 rows
102- global char * ax [N_R0_Q8_0 ];
103- global half * ad [N_R0_Q8_0 ];
104- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
105- ulong offset_src0 = (first_row + row )* nb01 /34 ;
106- ax [row ] = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
107- ad [row ] = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
108- }
103+ uint offset_src0_base = first_row * nb01 ;
104+
105+ global char * ax0 , * ax1 , * ax2 , * ax3 ;
106+ global half * ad0 , * ad1 , * ad2 , * ad3 ;
107+ uint offset_src0 ;
108+
109+ offset_src0 = offset_src0_base + 0 * nb01 ;
110+ offset_src0 = offset_src0 /34 ;
111+ ax0 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
112+ ad0 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
109113
110- float yl [NB_Q8_0 ];
111- float sumf [N_R0_Q8_0 ] = { 0.f };
114+ offset_src0 = offset_src0_base + 1 * nb01 ;
115+ offset_src0 = offset_src0 /34 ;
116+ ax1 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
117+ ad1 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
118+
119+ offset_src0 = offset_src0_base + 2 * nb01 ;
120+ offset_src0 = offset_src0 /34 ;
121+ ax2 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
122+ ad2 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
123+
124+ offset_src0 = offset_src0_base + 3 * nb01 ;
125+ offset_src0 = offset_src0 /34 ;
126+ ax3 = (global char * ) ((global char * ) src0_q_cur + offset_src0 * sizeof (char )* QK8_0 );
127+ ad3 = (global half * ) ((global char * ) src0_d_cur + offset_src0 * sizeof (half ));
112128
113129 const short ix = get_sub_group_local_id ()/4 ;
114130 const short il = get_sub_group_local_id ()%4 ;
115131
116132 global float * yb = y + ix * QK8_0 + il * NB_Q8_0 ;
117133
134+ float8 yl ;
135+ float8 qv ;
136+ float4 sumf = 0.f ;
137+ float sumq = 0.f ;
138+ global char * qs ;
139+
118140 // each thread handles NB_Q8_0 quants at a time
119141 for (int ib = ix ; ib < nb ; ib += N_SIMDWIDTH /4 ) {
120- for (short i = 0 ; i < NB_Q8_0 ; ++ i ) {
121- yl [i ] = yb [i ];
122- }
123-
124- for (short row = 0 ; row < N_R0_Q8_0 ; row ++ ) {
125- global char * qs = ax [row ] + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
126- float sumq = 0.f ;
127- for (short iq = 0 ; iq < NB_Q8_0 ; ++ iq ) {
128- sumq += qs [iq ] * yl [iq ];
129- }
130- sumf [row ] += sumq * ad [row ][ib ];
131- }
142+ yl = vload8 (0 , yb );
143+
144+ qs = ax0 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
145+ qv = convert_float8 (vload8 (0 , qs ));
146+ sumq = 0 ;
147+ sumq += qv .s0 * yl .s0 ;
148+ sumq += qv .s1 * yl .s1 ;
149+ sumq += qv .s2 * yl .s2 ;
150+ sumq += qv .s3 * yl .s3 ;
151+ sumq += qv .s4 * yl .s4 ;
152+ sumq += qv .s5 * yl .s5 ;
153+ sumq += qv .s6 * yl .s6 ;
154+ sumq += qv .s7 * yl .s7 ;
155+ sumf .s0 += sumq * ad0 [ib ];
156+
157+ qs = ax1 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
158+ qv = convert_float8 (vload8 (0 , qs ));
159+ sumq = 0 ;
160+ sumq += qv .s0 * yl .s0 ;
161+ sumq += qv .s1 * yl .s1 ;
162+ sumq += qv .s2 * yl .s2 ;
163+ sumq += qv .s3 * yl .s3 ;
164+ sumq += qv .s4 * yl .s4 ;
165+ sumq += qv .s5 * yl .s5 ;
166+ sumq += qv .s6 * yl .s6 ;
167+ sumq += qv .s7 * yl .s7 ;
168+ sumf .s1 += sumq * ad1 [ib ];
169+
170+ qs = ax2 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
171+ qv = convert_float8 (vload8 (0 , qs ));
172+ sumq = 0 ;
173+ sumq += qv .s0 * yl .s0 ;
174+ sumq += qv .s1 * yl .s1 ;
175+ sumq += qv .s2 * yl .s2 ;
176+ sumq += qv .s3 * yl .s3 ;
177+ sumq += qv .s4 * yl .s4 ;
178+ sumq += qv .s5 * yl .s5 ;
179+ sumq += qv .s6 * yl .s6 ;
180+ sumq += qv .s7 * yl .s7 ;
181+ sumf .s2 += sumq * ad2 [ib ];
182+
183+ qs = ax3 + ib * sizeof (char )* QK8_0 + il * NB_Q8_0 ;
184+ qv = convert_float8 (vload8 (0 , qs ));
185+ sumq = 0 ;
186+ sumq += qv .s0 * yl .s0 ;
187+ sumq += qv .s1 * yl .s1 ;
188+ sumq += qv .s2 * yl .s2 ;
189+ sumq += qv .s3 * yl .s3 ;
190+ sumq += qv .s4 * yl .s4 ;
191+ sumq += qv .s5 * yl .s5 ;
192+ sumq += qv .s6 * yl .s6 ;
193+ sumq += qv .s7 * yl .s7 ;
194+ sumf .s3 += sumq * ad3 [ib ];
132195
133196 yb += N_SIMDWIDTH * NB_Q8_0 ;
134197 }
135198
136199 global float * dst_f32 = (global float * ) dst_cur + (ulong )r1 * ne0 ;
137200
138- for (int row = 0 ; row < N_R0_Q8_0 ; ++ row ) {
139- float tot = sub_group_reduce_add (sumf [row ]);
201+ float4 tot = (float4 )(
202+ sub_group_reduce_add (sumf .s0 ),
203+ sub_group_reduce_add (sumf .s1 ),
204+ sub_group_reduce_add (sumf .s2 ),
205+ sub_group_reduce_add (sumf .s3 )
206+ );
140207
141- if (get_sub_group_local_id () == 0 && first_row + row < ne01 ) {
142- dst_f32 [first_row + row ] = tot ;
208+ if (get_sub_group_local_id () == 0 ) {
209+ if (first_row + 0 < ne01 ) {
210+ dst_f32 [first_row + 0 ] = tot .s0 ;
211+ }
212+ if (first_row + 1 < ne01 ) {
213+ dst_f32 [first_row + 1 ] = tot .s1 ;
214+ }
215+ if (first_row + 2 < ne01 ) {
216+ dst_f32 [first_row + 2 ] = tot .s2 ;
217+ }
218+ if (first_row + 3 < ne01 ) {
219+ dst_f32 [first_row + 3 ] = tot .s3 ;
143220 }
144221 }
145222}
0 commit comments