@@ -2560,12 +2560,16 @@ typedef struct {
25602560 uint8_t qs[QK4_NL/2 ];
25612561} block_iq4_nl;
25622562
// block_iq4_xs: storage block type used by the iq4_xs kernels.
// When QK_K == 64 no separate layout is declared; the name is a preprocessor
// alias for block_iq4_nl, so every use of block_iq4_xs resolves to that type.
2563+ #if QK_K == 64
2564+ #define block_iq4_xs block_iq4_nl
2565+ #else
// Otherwise a dedicated struct is declared. Array sizes are derived from QK_K.
// NOTE(review): field meanings below are inferred from names and sizes - confirm
// against the quantization code before relying on them.
25632566typedef struct {
// presumably the per-block scale factor - TODO confirm
25642567 half d;
// presumably the high bits of the sub-block scales, packed into 16 bits - TODO confirm
25652568 uint16_t scales_h;
// QK_K/64 bytes; presumably the low bits of the sub-block scales - TODO confirm
25662569 uint8_t scales_l[QK_K/64 ];
// QK_K/2 bytes; presumably two 4-bit quantized values per byte - TODO confirm
25672570 uint8_t qs[QK_K/2 ];
25682571} block_iq4_xs;
2572+ #endif
25692573
25702574// ====================================== dot products =========================
25712575
@@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
43464350 threadgroup_barrier (mem_flags::mem_threadgroup);
43474351 }
43484352
4349- #if QK_K == 256
43504353 const int ix = tiisg;
43514354
43524355 device const float * y4 = y + 32 * ix;
@@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
43874390
43884391 y4 += 32 * 32 ;
43894392 }
4390- #else
4391- (void ) x;
4392- (void ) y;
4393- (void ) yl;
4394- (void ) nb32;
4395- #endif
43964393
43974394 for (int row = 0 ; row < N_DST; ++row) {
43984395 all_sum = simd_sum (sumf[row]);
@@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
44824479 threadgroup_barrier (mem_flags::mem_threadgroup);
44834480 }
44844481
4485- #if QK_K == 256
44864482 const int ix = tiisg;
44874483
44884484 device const float * y4 = y + 32 * ix;
@@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
45334529
45344530 y4 += 32 * 32 ;
45354531 }
4536- #else
4537- (void ) x;
4538- (void ) y;
4539- (void ) yl;
4540- (void ) nb32;
4541- #endif
45424532
45434533 for (int row = 0 ; row < N_DST; ++row) {
45444534 all_sum = simd_sum (sumf[row]);
@@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
46284618 threadgroup_barrier (mem_flags::mem_threadgroup);
46294619 }
46304620
4631- #if QK_K == 256
46324621 const int ix = tiisg;
46334622
46344623 device const float * y4 = y + 32 * ix;
@@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
46724661
46734662 y4 += 32 * 32 ;
46744663 }
4675- #else
4676- (void ) x;
4677- (void ) y;
4678- (void ) yl;
4679- (void ) nb32;
4680- #endif
46814664
46824665 for (int row = 0 ; row < N_DST; ++row) {
46834666 all_sum = simd_sum (sumf[row]);
@@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
50164999
50175000 const int nb32 = nb * (QK_K / 32 );
50185001
5019- #if QK_K == 256
50205002 const int ix = tiisg/2 ;
50215003 const int il = tiisg%2 ;
50225004
@@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
50555037
50565038 y4 += 16 * 32 ;
50575039 }
5058- #else
5059- (void ) x;
5060- (void ) y;
5061- (void ) yl;
5062- (void ) nb32;
5063- #endif
50645040
50655041 for (int row = 0 ; row < N_DST; ++row) {
50665042 all_sum = simd_sum (sumf[row]);
@@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
51675143 }
51685144}
51695145
5146+ #if QK_K != 64
51705147void kernel_mul_mv_iq4_xs_f32_impl (
51715148 device const void * src0,
51725149 device const float * src1,
@@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
52605237 }
52615238 }
52625239}
5240+ #endif
52635241
52645242[[host_name(" kernel_mul_mv_iq1_s_f32" )]]
52655243kernel void kernel_mul_mv_iq1_s_f32 (
@@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
53445322 uint tiisg[[thread_index_in_simdgroup]],
53455323 uint sgitg[[simdgroup_index_in_threadgroup]]) {
53465324
5325+ #if QK_K == 64
5326+ kernel_mul_mv_iq4_nl_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5327+ #else
53475328 kernel_mul_mv_iq4_xs_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5329+ #endif
53485330}
53495331
53505332// ============================= templates and their specializations =============================
@@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
57705752
57715753template <typename type4x4>
57725754void dequantize_iq4_xs (device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
5755+ #if QK_K == 64
5756+ dequantize_iq4_nl (xb, il, reg);
5757+ #else
57735758 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
57745759 const int ib32 = il/2 ;
57755760 il = il%2 ;
@@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
57865771 reg[i][2 ] = d * kvalues_iq4nl_f[q8[2 ]];
57875772 reg[i][3 ] = d * kvalues_iq4nl_f[q8[3 ]];
57885773 }
5774+ #endif
57895775}
57905776
57915777template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
@@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
63346320template [[host_name(" kernel_get_rows_iq2_s" )]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
63356321template [[host_name(" kernel_get_rows_iq1_s" )]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
63366322template [[host_name(" kernel_get_rows_iq4_nl" )]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2 , dequantize_iq4_nl>;
// Explicit instantiation of kernel_get_rows for iq4_xs, exported under the
// host name given in the attribute. Exactly one of the two branches is compiled.
// With QK_K == 64 the second template argument is 2, the same value the
// iq4_nl instantiation uses; otherwise it is QK_NL.
6323+ #if QK_K == 64
6324+ template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2 , dequantize_iq4_xs>;
6325+ #else
63376326template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6327+ #endif
63386328
63396329//
63406330// matrix-matrix multiplication
@@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
63786368template [[host_name(" kernel_mul_mm_iq2_s_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
63796369template [[host_name(" kernel_mul_mm_iq1_s_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
63806370template [[host_name(" kernel_mul_mm_iq4_nl_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2 , dequantize_iq4_nl>;
// Explicit instantiation of kernel_mul_mm for iq4_xs, exported under the
// host name given in the attribute. Exactly one of the two branches is compiled.
// With QK_K == 64 the second template argument (nl) is 2, the same value the
// iq4_nl instantiation uses; otherwise it is QK_NL.
// The block type is spelled block_iq4_xs in both branches, matching the
// get_rows and mul_mm_id instantiations; under QK_K == 64 that name is a
// macro alias for block_iq4_nl, so the instantiated type is unchanged.
6371+ #if QK_K == 64
6372+ template [[host_name(" kernel_mul_mm_iq4_xs_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2 , dequantize_iq4_xs>;
6373+ #else
63816374template [[host_name(" kernel_mul_mm_iq4_xs_f32" )]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6375+ #endif
63826376
63836377//
63846378// indirect matrix-matrix multiplication
@@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
64346428template [[host_name(" kernel_mul_mm_id_iq2_s_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
64356429template [[host_name(" kernel_mul_mm_id_iq1_s_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
64366430template [[host_name(" kernel_mul_mm_id_iq4_nl_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2 , dequantize_iq4_nl>;
// Explicit instantiation of kernel_mul_mm_id for iq4_xs, exported under the
// host name given in the attribute. Exactly one of the two branches is compiled.
// With QK_K == 64 the second template argument is 2, the same value the
// iq4_nl instantiation uses; otherwise it is QK_NL.
6431+ #if QK_K == 64
6432+ template [[host_name(" kernel_mul_mm_id_iq4_xs_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2 , dequantize_iq4_xs>;
6433+ #else
64376434template [[host_name(" kernel_mul_mm_id_iq4_xs_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6435+ #endif
64386436
64396437//
64406438// matrix-vector multiplication
@@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
77077705
77087706 const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
77097707
7708+ #if QK_K == 64
7709+ kernel_mul_mv_iq4_nl_f32_impl (
7710+ #else
77107711 kernel_mul_mv_iq4_xs_f32_impl (
7712+ #endif
77117713 src0[id],
77127714 (device const float *) (src1 + bid*nb11),
77137715 dst + bid*ne0,
0 commit comments