@@ -4975,9 +4975,9 @@ void kernel_mul_mv_q6_K_f32_impl(
49754975 device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
49764976 device const float * yy = (device const float *) (src1 + offset1);
49774977
4978- // TODO: support nr0 > 1
4979- static_assert (nr0 == 1 , " nr0 > 1 not supported " );
4980- float sumf[ 1 ] = { 0 . f } ;
4978+ float sumf[nr0] = { 0 . f };
4979+
4980+ float yl[ 16 ] ;
49814981
49824982 const short tid = tiisg/2 ;
49834983 const short ix = tiisg%2 ;
@@ -4995,22 +4995,37 @@ void kernel_mul_mv_q6_K_f32_impl(
49954995 device const uint8_t * q2 = q1 + 32 ;
49964996 device const uint8_t * qh = x[i].qh + q_offset_h;
49974997 device const int8_t * sc = x[i].scales + is;
4998+ device const half * dh = &x[i].d ;
49984999
49995000 device const float * y = yy + i * QK_K + y_offset;
50005001
5001- const float dall = x[i].d ;
5002-
5003- float4 sums = {0 .f , 0 .f , 0 .f , 0 .f };
5004-
5005- #pragma unroll(4)
50065002 for (short l = 0 ; l < 4 ; ++l) {
5007- sums[ 0 ] += y[l+ 0 ] * (( int8_t )((q1[l] & 0xF ) | ((qh[l] & kmask1) << 4 )) - 32 ) ;
5008- sums[ 1 ] + = y[l+ 32 ] * (( int8_t )((q2[l] & 0xF ) | ((qh[l] & kmask2) << 2 )) - 32 ) ;
5009- sums[ 2 ] + = y[l+ 64 ] * (( int8_t )((q1[l] >> 4 ) | ((qh[l] & kmask3) << 0 )) - 32 ) ;
5010- sums[ 3 ] + = y[l+ 96 ] * (( int8_t )((q2[l] >> 4 ) | ((qh[l] & kmask4) >> 2 )) - 32 ) ;
5003+ yl[ 4 *l + 0 ] = y[l + 0 ] ;
5004+ yl[ 4 *l + 1 ] = y[l + 32 ] ;
5005+ yl[ 4 *l + 2 ] = y[l + 64 ];
5006+ yl[ 4 *l + 3 ] = y[l + 96 ];
50115007 }
50125008
5013- sumf[0 ] += dall * (sums[0 ] * sc[0 ] + sums[1 ] * sc[2 ] + sums[2 ] * sc[4 ] + sums[3 ] * sc[6 ]);
5009+ for (short row = 0 ; row < nr0; ++row) {
5010+ const float dall = dh[0 ];
5011+
5012+ float4 sums = {0 .f , 0 .f , 0 .f , 0 .f };
5013+
5014+ for (short l = 0 ; l < 4 ; ++l) {
5015+ sums[0 ] += yl[4 *l + 0 ] * ((int8_t )((q1[l] & 0xF ) | ((qh[l] & kmask1) << 4 )) - 32 );
5016+ sums[1 ] += yl[4 *l + 1 ] * ((int8_t )((q2[l] & 0xF ) | ((qh[l] & kmask2) << 2 )) - 32 );
5017+ sums[2 ] += yl[4 *l + 2 ] * ((int8_t )((q1[l] >> 4 ) | ((qh[l] & kmask3) << 0 )) - 32 );
5018+ sums[3 ] += yl[4 *l + 3 ] * ((int8_t )((q2[l] >> 4 ) | ((qh[l] & kmask4) >> 2 )) - 32 );
5019+ }
5020+
5021+ sumf[row] += dall * (sums[0 ] * sc[0 ] + sums[1 ] * sc[2 ] + sums[2 ] * sc[4 ] + sums[3 ] * sc[6 ]);
5022+
5023+ q1 += args.nb01 ;
5024+ q2 += args.nb01 ;
5025+ qh += args.nb01 ;
5026+ sc += args.nb01 ;
5027+ dh += args.nb01 /2 ;
5028+ }
50145029 }
50155030
50165031 device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
0 commit comments