@@ -1441,18 +1441,18 @@ kernel void kernel_group_norm(
14411441inline float block_q_n_dot_y (device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
14421442 float d = qb_curr->d ;
14431443
1444- float acc[ 4 ] = { 0 .0f , 0 . 0f , 0 . 0f , 0 . 0f } ;
1444+ float acc = - 8 .0f *sumy ;
14451445
14461446 device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2 );
14471447
1448- for (int i = 0 ; i < 8 ; i += 2 ) {
1449- acc[ 0 ] += yl[i + 0 ] * (qs[i / 2 ] & 0x000F );
1450- acc[ 1 ] += yl[i + 1 ] * (qs[i / 2 ] & 0x0F00 );
1451- acc[ 2 ] += yl[i + 8 ] * (qs[i / 2 ] & 0x00F0 );
1452- acc[ 3 ] += yl[i + 9 ] * (qs[i / 2 ] & 0xF000 );
1448+ for (short i = 0 ; i < 4 ; ++i ) {
1449+ acc += yl[2 * i + 0 ] * (qs[i] & 0x000F );
1450+ acc += yl[2 * i + 1 ] * (qs[i] & 0x0F00 );
1451+ acc += yl[2 * i + 8 ] * (qs[i] & 0x00F0 );
1452+ acc += yl[2 * i + 9 ] * (qs[i] & 0xF000 );
14531453 }
14541454
1455- return d * (sumy * - 8 . f + acc[ 0 ] + acc[ 1 ] + acc[ 2 ] + acc[ 3 ]) ;
1455+ return d * acc;
14561456}
14571457
14581458// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1567,37 +1567,36 @@ void mul_vec_q_n_f32_impl(
15671567 float yl[16 ]; // src1 vector cache
15681568 float sumf[nr] = {0 .f };
15691569
1570- const short ix = (tiisg/ 2 );
1571- const short il = (tiisg% 2 )*8 ;
1570+ const short ix = (tiisg% 16 );
1571+ const short il = (tiisg/ 16 )*8 ;
15721572
15731573 device const float * yb = y + ix*QK4_0 + il;
15741574
15751575 // each thread in a SIMD group deals with half a block.
15761576 for (int ib = ix; ib < nb; ib += nw/2 ) {
1577- float sumy[ 2 ] = { 0 . f , 0 . f } ;
1577+ float sumy = 0 . 0f ;
15781578
1579- #pragma unroll
1579+ #pragma unroll(4)
15801580 for (int i = 0 ; i < 8 ; i += 2 ) {
1581- sumy[0 ] += yb[i + 0 ] + yb[i + 1 ];
1581+ sumy += yb[i + 0 ] + yb[i + 1 ] + yb[i + 16 ] + yb[i + 17 ];
1582+
15821583 yl[i + 0 ] = yb[i + 0 ];
15831584 yl[i + 1 ] = yb[i + 1 ]/256 .f ;
1584-
1585- sumy[1 ] += yb[i + 16 ] + yb[i + 17 ];
15861585 yl[i + 8 ] = yb[i + 16 ]/16 .f ;
15871586 yl[i + 9 ] = yb[i + 17 ]/4096 .f ;
15881587 }
15891588
1590- #pragma unroll
1591- for (int row = 0 ; row < nr; row++) {
1592- sumf[row] += block_q_n_dot_y (ax[row] + ib, sumy[ 0 ] + sumy[ 1 ] , yl, il);
1589+ #pragma unroll(nr)
1590+ for (short row = 0 ; row < nr; row++) {
1591+ sumf[row] += block_q_n_dot_y (ax[row] + ib, sumy, yl, il);
15931592 }
15941593
15951594 yb += QK4_0 * 16 ;
15961595 }
15971596
15981597 device float * dst_f32 = (device float *) dst + im*args.ne0 *args.ne1 + r1*args.ne0 ;
15991598
1600- for (int row = 0 ; row < nr; ++row) {
1599+ for (short row = 0 ; row < nr; ++row) {
16011600 const float tot = simd_sum (sumf[row]);
16021601
16031602 if (tiisg == 0 && first_row + row < args.ne01 ) {
0 commit comments