@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
373373template <typename type4x4>
374374void dequantize_q6_K (device const block_q6_K *xb, short il, thread type4x4 & reg) {
375375 const half d_all = xb->d ;
376- device const uint8_t * ql = (device const uint8_t *)xb->ql ;
377- device const uint8_t * qh = (device const uint8_t *)xb->qh ;
376+ device const uint16_t * ql = (device const uint16_t *)xb->ql ;
377+ device const uint16_t * qh = (device const uint16_t *)xb->qh ;
378378 device const int8_t * scales = (device const int8_t *)xb->scales ;
379379
380- ql = ql + 64 *(il/8 ) + 32 *((il/2 )&1 ) + 16 *(il&1 );
381- qh = qh + 32 *(il/8 ) + 16 *(il&1 );
380+ ql = ql + 32 *(il/8 ) + 16 *((il/2 )&1 ) + 8 *(il&1 );
381+ qh = qh + 16 *(il/8 ) + 8 *(il&1 );
382382 float sc = scales[(il%2 ) + 2 * ((il/2 ))];
383383 il = (il/2 ) & 3 ;
384384
385- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48 ) : (il>0 ? 12 : 3 );
386- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F ;
387- const float coef = il>1 ? 1 .f /16 .f : 1 .f ;
385+ const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030 ) : (il>0 ? 0x0C0C0C0C : 0x03030303 );
386+ const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F ;
388387 const float ml = d_all * sc * 32 .f ;
389- const float dl = d_all * sc * coef;
390- for (int i = 0 ; i < 16 ; ++i) {
391- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2 ))
392- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4 ));
393- reg[i/4 ][i%4 ] = dl * q - ml;
388+ const float dl0 = d_all * sc;
389+ const float dl1 = dl0 / 256 .f ;
390+ const float dl2 = dl0 / (256 .f * 256 .f );
391+ const float dl3 = dl0 / (256 .f * 256 .f * 256 .f );
392+ const uint8_t shr_h = il>2 ? 2 : 0 ;
393+ const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4 );
394+ const uint8_t shr_l = il>1 ? 4 : 0 ;
395+ for (int i = 0 ; i < 4 ; ++i) {
396+ const uint32_t low = (ql[2 *i] | (uint32_t )(ql[2 *i+1 ] << 16 )) & kmask2;
397+ const uint32_t high = (qh[2 *i] | (uint32_t )(qh[2 *i+1 ] << 16 )) & kmask1;
398+ const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
399+ reg[i][0 ] = dl0 * ((half)(q & 0xFF )) - ml;
400+ reg[i][1 ] = dl1 * ((float )(q & 0xFF00 )) - ml;
401+ reg[i][2 ] = dl2 * ((float )(q & 0xFF0000 )) - ml;
402+ reg[i][3 ] = dl3 * ((float )(q & 0xFF000000 )) - ml;
394403 }
395404}
396405
0 commit comments