@@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
74877487 kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr , tgpig, tiisg, sgitg);
74887488}
74897489
7490- template <int nr0 , typename args_t >
7490+ template <int NR0 , typename args_t >
74917491void kernel_mul_mv_iq4_nl_f32_impl (
74927492 args_t args,
74937493 device const char * src0,
@@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
75007500 const short NSG = FC_mul_mv_nsg;
75017501
75027502 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7503- const int nb = args.ne00 /QK4_NL;
75047503
75057504 const int r0 = tgpig.x ;
75067505 const int r1 = tgpig.y ;
75077506 const int im = tgpig.z ;
75087507
7509- const int first_row = (r0 * NSG + sgitg) * nr0 ;
7508+ const int first_row = (r0 * NSG + sgitg) * NR0 ;
75107509
75117510 const uint i12 = im%args.ne12 ;
75127511 const uint i13 = im/args.ne12 ;
@@ -7517,31 +7516,35 @@ void kernel_mul_mv_iq4_nl_f32_impl(
75177516 device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
75187517 device const float * y = (device const float *) (src1 + offset1);
75197518
7519+ const int nb = args.ne00 /QK4_NL;
7520+ const int ns01 = args.nb01 /args.nb00 ;
7521+
75207522 const short ix = tiisg/2 ; // 0...15
75217523 const short it = tiisg%2 ; // 0 or 1
75227524
75237525 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16 ];
75247526 threadgroup_barrier (mem_flags::mem_threadgroup);
75257527
75267528 float4 yl[4 ];
7527- float sumf[nr0 ]={0 .f };
7529+ float sumf[NR0 ]={0 .f };
75287530
7529- device const float * yb = y + ix * QK4_NL + it * 8 ;
7531+ device const float * yb = y + ix* QK4_NL + it* 8 ;
75307532
75317533 uint32_t aux32[2 ];
75327534 thread const uint8_t * q8 = (thread const uint8_t *)aux32;
75337535
75347536 float4 qf1, qf2;
75357537
7536- for (int ib = ix; ib < nb; ib += 16 ) {
7538+ // [TAG_MUL_MV_WEIRD]
7539+ for (int ib = ix; ib < nb && ib < ns01; ib += 16 ) {
75377540 device const float4 * y4 = (device const float4 *)yb;
75387541 yl[0 ] = y4[0 ];
75397542 yl[1 ] = y4[4 ];
75407543 yl[2 ] = y4[1 ];
75417544 yl[3 ] = y4[5 ];
75427545
7543- for (short row = 0 ; row < nr0 ; row++) {
7544- device const block_iq4_nl & xb = x[row*nb + ib];
7546+ for (short row = 0 ; row < NR0 ; row++) {
7547+ device const block_iq4_nl & xb = x[row*ns01 + ib];
75457548 device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8 *it);
75467549
75477550 float4 acc1 = {0 .f }, acc2 = {0 .f };
@@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
75727575
75737576 device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
75747577
7575- for (int row = 0 ; row < nr0 && first_row + row < args.ne0 ; ++row) {
7578+ for (int row = 0 ; row < NR0 && first_row + row < args.ne0 ; ++row) {
75767579 float sum_all = simd_sum (sumf[row]);
75777580 if (tiisg == 0 ) {
75787581 dst_f32[first_row + row] = sum_all;
@@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
75947597 kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
75957598}
75967599
7597- template <int nr0 , typename args_t >
7600+ template <int NR0 , typename args_t >
75987601void kernel_mul_mv_iq4_xs_f32_impl (
75997602 args_t args,
76007603 device const char * src0,
@@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
76077610 const short NSG = FC_mul_mv_nsg;
76087611
76097612 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7610- const int nb = args.ne00 /QK_K;
76117613
76127614 const int r0 = tgpig.x ;
76137615 const int r1 = tgpig.y ;
76147616 const int im = tgpig.z ;
7615- const int first_row = (r0 * NSG + sgitg) * nr0 ;
7617+ const int first_row = (r0 * NSG + sgitg) * NR0 ;
76167618
76177619 const uint i12 = im%args.ne12 ;
76187620 const uint i13 = im/args.ne12 ;
@@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
76237625 device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
76247626 device const float * y = (device const float *) (src1 + offset1);
76257627
7628+ const int nb = args.ne00 /QK_K;
7629+ const int ns01 = args.nb01 /args.nb00 ;
7630+
76267631 const short ix = tiisg/16 ; // 0 or 1
76277632 const short it = tiisg%16 ; // 0...15
76287633 const short ib = it/2 ;
@@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
76327637 threadgroup_barrier (mem_flags::mem_threadgroup);
76337638
76347639 float4 yl[4 ];
7635- float sumf[nr0 ]={0 .f };
7640+ float sumf[NR0 ]={0 .f };
76367641
76377642 device const float * yb = y + ix * QK_K + ib * 32 + il * 8 ;
76387643
@@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
76417646
76427647 float4 qf1, qf2;
76437648
7644- for (int ibl = ix; ibl < nb; ibl += 2 ) {
7649+ // [TAG_MUL_MV_WEIRD]
7650+ for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2 ) {
76457651 device const float4 * y4 = (device const float4 *)yb;
76467652 yl[0 ] = y4[0 ];
76477653 yl[1 ] = y4[4 ];
76487654 yl[2 ] = y4[1 ];
76497655 yl[3 ] = y4[5 ];
76507656
7651- for (short row = 0 ; row < nr0 ; ++row) {
7652- device const block_iq4_xs & xb = x[row*nb + ibl];
7657+ for (short row = 0 ; row < NR0 ; ++row) {
7658+ device const block_iq4_xs & xb = x[row*ns01 + ibl];
76537659 device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16 *ib + 8 *il);
76547660
76557661 float4 acc1 = {0 .f }, acc2 = {0 .f };
@@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
76797685
76807686 device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
76817687
7682- for (int row = 0 ; row < nr0 && first_row + row < args.ne0 ; ++row) {
7688+ for (int row = 0 ; row < NR0 && first_row + row < args.ne0 ; ++row) {
76837689 float sum_all = simd_sum (sumf[row]);
76847690 if (tiisg == 0 ) {
76857691 dst_f32[first_row + row] = sum_all;
@@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
77017707 kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
77027708}
77037709
7704- template <int nr0 , typename args_t >
7710+ template <int NR0 , typename args_t >
77057711void kernel_mul_mv_mxfp4_f32_impl (
77067712 args_t args,
77077713 device const char * src0,
@@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
77147720 const short NSG = FC_mul_mv_nsg;
77157721
77167722 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7717- const int nb = args.ne00 /QK_MXFP4;
77187723
77197724 const int r0 = tgpig.x ;
77207725 const int r1 = tgpig.y ;
77217726 const int im = tgpig.z ;
77227727
7723- const int first_row = (r0 * NSG + sgitg) * nr0 ;
7728+ const int first_row = (r0 * NSG + sgitg) * NR0 ;
77247729
77257730 const uint i12 = im%args.ne12 ;
77267731 const uint i13 = im/args.ne12 ;
@@ -7731,27 +7736,32 @@ void kernel_mul_mv_mxfp4_f32_impl(
77317736 device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
77327737 device const float * y = (device const float *) (src1 + offset1);
77337738
7739+ const int nb = args.ne00 /QK_MXFP4;
7740+ const int ns01 = args.nb01 /args.nb00 ; // this can be larger than nb for permuted src0 tensors
7741+
77347742 const short ix = tiisg/2 ; // 0...15
77357743 const short it = tiisg%2 ; // 0 or 1
77367744
77377745 shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16 ];
77387746 threadgroup_barrier (mem_flags::mem_threadgroup);
77397747
77407748 float4 yl[4 ];
7741- float sumf[nr0 ]={0 .f };
7749+ float sumf[NR0 ]={0 .f };
77427750
7743- device const float * yb = y + ix * QK_MXFP4 + it * 8 ;
7751+ device const float * yb = y + ix*QK_MXFP4 + it*8 ;
7752+
7753+ // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
7754+ // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
7755+ for (int ib = ix; ib < nb && ib < ns01; ib += 16 ) {
7756+ device const float4 * y4 = (device const float4 *) yb;
77447757
7745- for (int ib = ix; ib < nb; ib += 16 ) {
7746- device const float4 * y4 = (device const float4 *)yb;
77477758 yl[0 ] = y4[0 ];
77487759 yl[1 ] = y4[4 ];
77497760 yl[2 ] = y4[1 ];
77507761 yl[3 ] = y4[5 ];
77517762
7752- #pragma unroll(nr0)
7753- for (short row = 0 ; row < nr0; row++) {
7754- device const block_mxfp4 & xb = x[row*nb + ib];
7763+ FOR_UNROLL (short row = 0 ; row < NR0; row++) {
7764+ device const block_mxfp4 & xb = x[row*ns01 + ib];
77557765 device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8 *it);
77567766
77577767 float4 acc1 = yl[0 ]*float4 (shmem_f32[q2[0 ] & 0x0F ], shmem_f32[q2[1 ] & 0x0F ], shmem_f32[q2[2 ] & 0x0F ], shmem_f32[q2[3 ] & 0x0F ]);
@@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
77697779
77707780 device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
77717781
7772- for (int row = 0 ; row < nr0 && first_row + row < args.ne0 ; ++row) {
7782+ for (int row = 0 ; row < NR0 && first_row + row < args.ne0 ; ++row) {
77737783 float sum_all = simd_sum (sumf[row]);
77747784 if (tiisg == 0 ) {
77757785 dst_f32[first_row + row] = sum_all;
0 commit comments