@@ -68,8 +68,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
6868    src2  =  (global  char  * )((global  char  * )src2  +  offset2 );
6969    dst   =  (global  char  * )((global  char  * )dst   +  offsetd );
7070
71-     int  iid1  =  get_group_id (2 )/ne20 ;
72-     int  idx   =  get_group_id (2 )%ne20 ;
71+     int  iid1  =  ( int ) get_group_id (2 )/ne20 ;
72+     int  idx   =  ( int ) get_group_id (2 )%ne20 ;
7373
7474    int  i02  =  ((global  int  * ) (src2  +  iid1 * nb21 ))[idx ];
7575
@@ -80,7 +80,8 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
8080    int  i2  =  i12_ ;
8181
8282    // 34 == sizeof(block_q8_0) 
83-     ulong  src0_off  =  i02 * nb02 /34 ;
83+     uint  src0_off  =  i02 * nb02 ;
84+     src0_off  /= 34 ;
8485
8586    global  char  *  src0_q_cur  =  src0_q  +  src0_off * sizeof (char )* QK8_0 ;
8687    global  half  *  src0_d_cur  =  src0_d  +  src0_off ;
@@ -99,47 +100,123 @@ kernel void kernel_mul_mv_id_q8_0_f32_flat(
99100    global  float  *  y   =  (global  float  * ) (src1_cur  +  offset_src1 );
100101
101102    // pointers to src0 rows 
102-     global  char  *  ax [N_R0_Q8_0 ];
103-     global  half  *  ad [N_R0_Q8_0 ];
104-     for  (int  row  =  0 ; row  <  N_R0_Q8_0 ; ++ row ) {
105-         ulong  offset_src0  =  (first_row  +  row )* nb01 /34 ;
106-         ax [row ] =  (global  char  * ) ((global  char  * ) src0_q_cur  +  offset_src0 * sizeof (char )* QK8_0 );
107-         ad [row ] =  (global  half  * ) ((global  char  * ) src0_d_cur  +  offset_src0 * sizeof (half ));
108-     }
103+     uint  offset_src0_base  =  first_row * nb01 ;
104+ 
105+     global  char  *  ax0 , *  ax1 , *  ax2 , *  ax3 ;
106+     global  half  *  ad0 , *  ad1 , *  ad2 , *  ad3 ;
107+     uint  offset_src0 ;
108+ 
109+     offset_src0  =  offset_src0_base  +  0 * nb01 ;
110+     offset_src0  =  offset_src0 /34 ;
111+     ax0  =  (global  char  * ) ((global  char  * ) src0_q_cur  +  offset_src0 * sizeof (char )* QK8_0 );
112+     ad0  =  (global  half  * ) ((global  char  * ) src0_d_cur  +  offset_src0 * sizeof (half ));
109113
110-     float  yl [NB_Q8_0 ];
111-     float  sumf [N_R0_Q8_0 ] =  { 0.f  };
114+     offset_src0  =  offset_src0_base  +  1 * nb01 ;
115+     offset_src0  =  offset_src0 /34 ;
116+     ax1  =  (global  char  * ) ((global  char  * ) src0_q_cur  +  offset_src0 * sizeof (char )* QK8_0 );
117+     ad1  =  (global  half  * ) ((global  char  * ) src0_d_cur  +  offset_src0 * sizeof (half ));
118+ 
119+     offset_src0  =  offset_src0_base  +  2 * nb01 ;
120+     offset_src0  =  offset_src0 /34 ;
121+     ax2  =  (global  char  * ) ((global  char  * ) src0_q_cur  +  offset_src0 * sizeof (char )* QK8_0 );
122+     ad2  =  (global  half  * ) ((global  char  * ) src0_d_cur  +  offset_src0 * sizeof (half ));
123+ 
124+     offset_src0  =  offset_src0_base  +  3 * nb01 ;
125+     offset_src0  =  offset_src0 /34 ;
126+     ax3  =  (global  char  * ) ((global  char  * ) src0_q_cur  +  offset_src0 * sizeof (char )* QK8_0 );
127+     ad3  =  (global  half  * ) ((global  char  * ) src0_d_cur  +  offset_src0 * sizeof (half ));
112128
113129    const  short  ix  =  get_sub_group_local_id ()/4 ;
114130    const  short  il  =  get_sub_group_local_id ()%4 ;
115131
116132    global  float  *  yb  =  y  +  ix * QK8_0  +  il * NB_Q8_0 ;
117133
134+     float8  yl ;
135+     float8  qv ;
136+     float4  sumf  =  0.f ;
137+     float   sumq  =  0.f ;
138+     global  char  *  qs ;
139+ 
118140    // each thread handles NB_Q8_0 quants at a time 
119141    for  (int  ib  =  ix ; ib  <  nb ; ib  +=  N_SIMDWIDTH /4 ) {
120-         for  (short  i  =  0 ; i  <  NB_Q8_0 ; ++ i ) {
121-             yl [i ] =  yb [i ];
122-         }
123- 
124-         for  (short  row  =  0 ; row  <  N_R0_Q8_0 ; row ++ ) {
125-             global  char  *  qs  =  ax [row ] +  ib * sizeof (char )* QK8_0  +  il * NB_Q8_0 ;
126-             float  sumq  =  0.f ;
127-             for  (short  iq  =  0 ; iq  <  NB_Q8_0 ; ++ iq ) {
128-                 sumq  +=  qs [iq ] *  yl [iq ];
129-             }
130-             sumf [row ] +=  sumq * ad [row ][ib ];
131-         }
142+         yl  =  vload8 (0 , yb );
143+ 
144+         qs  =  ax0  +  ib * sizeof (char )* QK8_0  +  il * NB_Q8_0 ;
145+         qv  =  convert_float8 (vload8 (0 , qs ));
146+         sumq  =  0 ;
147+         sumq  +=  qv .s0 * yl .s0 ;
148+         sumq  +=  qv .s1 * yl .s1 ;
149+         sumq  +=  qv .s2 * yl .s2 ;
150+         sumq  +=  qv .s3 * yl .s3 ;
151+         sumq  +=  qv .s4 * yl .s4 ;
152+         sumq  +=  qv .s5 * yl .s5 ;
153+         sumq  +=  qv .s6 * yl .s6 ;
154+         sumq  +=  qv .s7 * yl .s7 ;
155+         sumf .s0  +=  sumq * ad0 [ib ];
156+ 
157+         qs  =  ax1  +  ib * sizeof (char )* QK8_0  +  il * NB_Q8_0 ;
158+         qv  =  convert_float8 (vload8 (0 , qs ));
159+         sumq  =  0 ;
160+         sumq  +=  qv .s0 * yl .s0 ;
161+         sumq  +=  qv .s1 * yl .s1 ;
162+         sumq  +=  qv .s2 * yl .s2 ;
163+         sumq  +=  qv .s3 * yl .s3 ;
164+         sumq  +=  qv .s4 * yl .s4 ;
165+         sumq  +=  qv .s5 * yl .s5 ;
166+         sumq  +=  qv .s6 * yl .s6 ;
167+         sumq  +=  qv .s7 * yl .s7 ;
168+         sumf .s1  +=  sumq * ad1 [ib ];
169+ 
170+         qs  =  ax2  +  ib * sizeof (char )* QK8_0  +  il * NB_Q8_0 ;
171+         qv  =  convert_float8 (vload8 (0 , qs ));
172+         sumq  =  0 ;
173+         sumq  +=  qv .s0 * yl .s0 ;
174+         sumq  +=  qv .s1 * yl .s1 ;
175+         sumq  +=  qv .s2 * yl .s2 ;
176+         sumq  +=  qv .s3 * yl .s3 ;
177+         sumq  +=  qv .s4 * yl .s4 ;
178+         sumq  +=  qv .s5 * yl .s5 ;
179+         sumq  +=  qv .s6 * yl .s6 ;
180+         sumq  +=  qv .s7 * yl .s7 ;
181+         sumf .s2  +=  sumq * ad2 [ib ];
182+ 
183+         qs  =  ax3  +  ib * sizeof (char )* QK8_0  +  il * NB_Q8_0 ;
184+         qv  =  convert_float8 (vload8 (0 , qs ));
185+         sumq  =  0 ;
186+         sumq  +=  qv .s0 * yl .s0 ;
187+         sumq  +=  qv .s1 * yl .s1 ;
188+         sumq  +=  qv .s2 * yl .s2 ;
189+         sumq  +=  qv .s3 * yl .s3 ;
190+         sumq  +=  qv .s4 * yl .s4 ;
191+         sumq  +=  qv .s5 * yl .s5 ;
192+         sumq  +=  qv .s6 * yl .s6 ;
193+         sumq  +=  qv .s7 * yl .s7 ;
194+         sumf .s3  +=  sumq * ad3 [ib ];
132195
133196        yb  +=  N_SIMDWIDTH * NB_Q8_0 ;
134197    }
135198
136199    global  float  *  dst_f32  =  (global  float  * ) dst_cur  +  (ulong )r1 * ne0 ;
137200
138-     for  (int  row  =  0 ; row  <  N_R0_Q8_0 ; ++ row ) {
139-         float  tot  =  sub_group_reduce_add (sumf [row ]);
201+     float4  tot  =  (float4 )(
202+         sub_group_reduce_add (sumf .s0 ),
203+         sub_group_reduce_add (sumf .s1 ),
204+         sub_group_reduce_add (sumf .s2 ),
205+         sub_group_reduce_add (sumf .s3 )
206+     );
140207
141-         if  (get_sub_group_local_id () ==  0  &&  first_row  +  row  <  ne01 ) {
142-             dst_f32 [first_row  +  row ] =  tot ;
208+     if  (get_sub_group_local_id () ==  0 ) {
209+         if  (first_row  +  0  <  ne01 ) {
210+             dst_f32 [first_row  +  0 ] =  tot .s0 ;
211+         }
212+         if  (first_row  +  1  <  ne01 ) {
213+             dst_f32 [first_row  +  1 ] =  tot .s1 ;
214+         }
215+         if  (first_row  +  2  <  ne01 ) {
216+             dst_f32 [first_row  +  2 ] =  tot .s2 ;
217+         }
218+         if  (first_row  +  3  <  ne01 ) {
219+             dst_f32 [first_row  +  3 ] =  tot .s3 ;
143220        }
144221    }
145222}
0 commit comments