@@ -6251,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
62516251 ushort sgitg) {
62526252 const short NSG = FC_mul_mv_nsg;
62536253
6254- const uint16_t kmask1 = 0x3f3f ;
6255- const uint16_t kmask2 = 0x0f0f ;
6256- const uint16_t kmask3 = 0xc0c0 ;
6254+ constexpr uint16_t kmask1 = 0x3f3f ;
6255+ constexpr uint16_t kmask2 = 0x0f0f ;
6256+ constexpr uint16_t kmask3 = 0xc0c0 ;
62576257
62586258 const short ix = tiisg/8 ; // 0...3
62596259 const short it = tiisg%8 ; // 0...7
@@ -6312,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
63126312 float4 acc1 = {0 .f , 0 .f , 0 .f , 0 .f };
63136313 float4 acc2 = {0 .f , 0 .f , 0 .f , 0 .f };
63146314
6315- for (short i = 0 ; i < 4 ; ++i) {
6315+ FOR_UNROLL (short i = 0 ; i < 4 ; ++i) {
63166316 acc1[0 ] += yl[2 *i + 0 ] * (q1[i] & 0x000F );
63176317 acc1[1 ] += yl[2 *i + 1 ] * (q1[i] & 0x0F00 );
63186318 acc1[2 ] += yl[2 *i + 8 ] * (q1[i] & 0x00F0 );
@@ -6323,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
63236323 acc2[3 ] += yh[2 *i + 9 ] * (q2[i] & 0xF000 );
63246324 }
63256325
6326- float dall = dh[0 ];
6327- float dmin = dh[1 ];
6328-
6329- sumf[row] += dall * ((acc1[0 ] + 1 .f /256 .f * acc1[1 ]) * sc8[0 ] +
6330- (acc1[2 ] + 1 .f /256 .f * acc1[3 ]) * sc8[1 ] * 1 .f /16 .f +
6331- (acc2[0 ] + 1 .f /256 .f * acc2[1 ]) * sc8[4 ] +
6332- (acc2[2 ] + 1 .f /256 .f * acc2[3 ]) * sc8[5 ] * 1 .f /16 .f ) -
6333- dmin * (sumy[0 ] * sc8[2 ] + sumy[1 ] * sc8[3 ] + sumy[2 ] * sc8[6 ] + sumy[3 ] * sc8[7 ]);
6326+ sumf[row] += dh[0 ] * ((acc1[0 ] + 1 .f /256 .f * acc1[1 ]) * sc8[0 ] +
6327+ (acc1[2 ] + 1 .f /256 .f * acc1[3 ]) * sc8[1 ] * 1 .f /16 .f +
6328+ (acc2[0 ] + 1 .f /256 .f * acc2[1 ]) * sc8[4 ] +
6329+ (acc2[2 ] + 1 .f /256 .f * acc2[3 ]) * sc8[5 ] * 1 .f /16 .f ) -
6330+ dh[1 ] * (sumy[0 ] * sc8[2 ] + sumy[1 ] * sc8[3 ] + sumy[2 ] * sc8[6 ] + sumy[3 ] * sc8[7 ]);
63346331
63356332 q1 += args.nb01 /2 ;
63366333 sc += args.nb01 /2 ;
@@ -6396,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
63966393
63976394 float yl[16 ], yh[16 ];
63986395
6399- const uint16_t kmask1 = 0x3f3f ;
6400- const uint16_t kmask2 = 0x0f0f ;
6401- const uint16_t kmask3 = 0xc0c0 ;
6396+ constexpr uint16_t kmask1 = 0x3f3f ;
6397+ constexpr uint16_t kmask2 = 0x0f0f ;
6398+ constexpr uint16_t kmask3 = 0xc0c0 ;
64026399
64036400 const short tid = tiisg/4 ;
64046401 const short ix = tiisg%4 ;
@@ -6444,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
64446441
64456442 float4 acc1 = {0 .f };
64466443 float4 acc2 = {0 .f };
6447- for (short l = 0 ; l < 8 ; ++l) {
6444+ FOR_UNROLL (short l = 0 ; l < 8 ; ++l) {
64486445 uint8_t h = qh[l];
64496446 acc1[0 ] += yl[l+0 ] * (q1[l] & 0x0F );
64506447 acc1[1 ] += yl[l+8 ] * (q1[l] & 0xF0 );
@@ -6455,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
64556452 acc2[2 ] += h & hm3 ? yh[l+0 ] : 0 .f ;
64566453 acc2[3 ] += h & hm4 ? yh[l+8 ] : 0 .f ;
64576454 }
6458- const float dall = dh[0 ];
6459- const float dmin = dh[1 ];
6460- sumf[row] += dall * (sc8[0 ] * (acc1[0 ] + 16 .f *acc2[0 ]) +
6461- sc8[1 ] * (acc1[1 ]/16 .f + 16 .f *acc2[1 ]) +
6462- sc8[4 ] * (acc1[2 ] + 16 .f *acc2[2 ]) +
6463- sc8[5 ] * (acc1[3 ]/16 .f + 16 .f *acc2[3 ])) -
6464- dmin * (sumy[0 ] * sc8[2 ] + sumy[1 ] * sc8[3 ] + sumy[2 ] * sc8[6 ] + sumy[3 ] * sc8[7 ]);
6455+
6456+ sumf[row] += dh[0 ] * (sc8[0 ] * (acc1[0 ] + 16 .f *acc2[0 ]) +
6457+ sc8[1 ] * (acc1[1 ]/16 .f + 16 .f *acc2[1 ]) +
6458+ sc8[4 ] * (acc1[2 ] + 16 .f *acc2[2 ]) +
6459+ sc8[5 ] * (acc1[3 ]/16 .f + 16 .f *acc2[3 ])) -
6460+ dh[1 ] * (sumy[0 ] * sc8[2 ] + sumy[1 ] * sc8[3 ] + sumy[2 ] * sc8[6 ] + sumy[3 ] * sc8[7 ]);
64656461
64666462 q1 += args.nb01 ;
64676463 qh += args.nb01 ;
@@ -6507,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
65076503 ushort sgitg) {
65086504 const short NSG = FC_mul_mv_nsg;
65096505
6510- const uint8_t kmask1 = 0x03 ;
6511- const uint8_t kmask2 = 0x0C ;
6512- const uint8_t kmask3 = 0x30 ;
6513- const uint8_t kmask4 = 0xC0 ;
6506+ constexpr uint8_t kmask1 = 0x03 ;
6507+ constexpr uint8_t kmask2 = 0x0C ;
6508+ constexpr uint8_t kmask3 = 0x30 ;
6509+ constexpr uint8_t kmask4 = 0xC0 ;
65146510
65156511 const int nb = args.ne00 /QK_K;
65166512
@@ -6561,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
65616557 }
65626558
65636559 for (short row = 0 ; row < nr0; ++row) {
6564- const float dall = dh[0 ];
6565-
65666560 float4 sums = {0 .f , 0 .f , 0 .f , 0 .f };
65676561
6568- for (short l = 0 ; l < 4 ; ++l) {
6562+ FOR_UNROLL (short l = 0 ; l < 4 ; ++l) {
65696563 sums[0 ] += yl[4 *l + 0 ] * ((int8_t )((q1[l] & 0xF ) | ((qh[l] & kmask1) << 4 )) - 32 );
65706564 sums[1 ] += yl[4 *l + 1 ] * ((int8_t )((q2[l] & 0xF ) | ((qh[l] & kmask2) << 2 )) - 32 );
65716565 sums[2 ] += yl[4 *l + 2 ] * ((int8_t )((q1[l] >> 4 ) | ((qh[l] & kmask3) << 0 )) - 32 );
65726566 sums[3 ] += yl[4 *l + 3 ] * ((int8_t )((q2[l] >> 4 ) | ((qh[l] & kmask4) >> 2 )) - 32 );
65736567 }
65746568
6575- sumf[row] += dall * (sums[0 ] * sc[0 ] + sums[1 ] * sc[2 ] + sums[2 ] * sc[4 ] + sums[3 ] * sc[6 ]);
6569+ sumf[row] += dh[ 0 ] * (sums[0 ] * sc[0 ] + sums[1 ] * sc[2 ] + sums[2 ] * sc[4 ] + sums[3 ] * sc[6 ]);
65766570
65776571 q1 += args.nb01 ;
65786572 q2 += args.nb01 ;
0 commit comments