Commit a3cb047

metal : fix mul-mm condition + fix mul-mv permuted kernels (ggml-org#16494)
1 parent 4a8fbe0 commit a3cb047

File tree: 3 files changed (+44, -35 lines)

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 2 additions & 3 deletions
@@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
         !ggml_is_transposed(op->src[1]) &&
         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-        props_dev->has_simdgroup_mm && ne00 >= 64 &&
-        (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
-        //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+        props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
+        //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
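The condition change above means the batched-quantized special case (ggml_is_quantized(op->src[0]->type) && ne12 > 1) no longer forces the matrix-matrix path; the choice between the mul-mm and mul-mv kernels now rests solely on ne11 > ne11_mm_min. A minimal sketch of the resulting predicate, pieced together from the context lines above (the helper name use_mul_mm is illustrative, not part of the patch):

    #include <cstdint>

    // Sketch only: mirrors the condition in ggml_metal_op_mul_mat after this commit.
    // The arguments stand in for props_dev->has_simdgroup_mm, ggml_is_transposed(op->src[1])
    // and the ne00/ne11/ne11_mm_min values used by the original function.
    bool use_mul_mm(bool has_simdgroup_mm, bool src1_transposed,
                    int64_t ne00, int64_t ne11, int64_t ne11_mm_min) {
        // the matrix-matrix kernel needs simdgroup matrix support (A14+/M1+),
        // rows of at least 64 elements, and enough src1 columns to amortize it;
        // otherwise the matrix-vector kernels are used
        return !src1_transposed &&
               has_simdgroup_mm &&
               ne00 >= 64       &&
               ne11 > ne11_mm_min;
    }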

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 38 additions & 28 deletions
@@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
     kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }

-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     const short NSG = FC_mul_mv_nsg;

     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK4_NL;

     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;

-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;

     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7517,31 +7516,35 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);

+    const int nb = args.ne00/QK4_NL;
+    const int ns01 = args.nb01/args.nb00;
+
     const short ix = tiisg/2; // 0...15
     const short it = tiisg%2; // 0 or 1

     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);

     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};

-    device const float * yb = y + ix * QK4_NL + it * 8;
+    device const float * yb = y + ix*QK4_NL + it*8;

     uint32_t aux32[2];
     thread const uint8_t * q8 = (thread const uint8_t *)aux32;

     float4 qf1, qf2;

-    for (int ib = ix; ib < nb; ib += 16) {
+    // [TAG_MUL_MV_WEIRD]
+    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];

-        for (short row = 0; row < nr0; row++) {
-            device const block_iq4_nl & xb = x[row*nb + ib];
+        for (short row = 0; row < NR0; row++) {
+            device const block_iq4_nl & xb = x[row*ns01 + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

             float4 acc1 = {0.f}, acc2 = {0.f};
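The key addition in the hunk above is ns01 = args.nb01/args.nb00: for a quantized src0, nb00 is the byte size of one quant block and nb01 is the byte stride between consecutive rows, so ns01 is the number of blocks separating one row from the next. For a contiguous matrix ns01 equals nb, but for a permuted or strided view it can be larger, which is why the row offset changes from x[row*nb + ib] to x[row*ns01 + ib]. A small host-side sketch of the arithmetic (the concrete sizes, including the assumed 18-byte IQ4_NL block and the 4x row stride of the view, are made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical IQ4_NL matrix: 4096 elements per row, 32 elements per quant block,
        // assumed 18 bytes per block (this plays the role of nb00)
        const int64_t ne00   = 4096;
        const int64_t QK4_NL = 32;
        const int64_t nb00   = 18;

        const int64_t nb = ne00/QK4_NL;              // blocks that carry data in one row

        const int64_t nb01_contig = nb00*nb;         // contiguous src0: rows are densely packed
        const int64_t nb01_view   = 4*nb01_contig;   // e.g. a view whose rows are 4 row-widths apart

        printf("nb   = %lld\n",                (long long) nb);
        printf("ns01 = %lld (contiguous)\n",   (long long) (nb01_contig/nb00));
        printf("ns01 = %lld (strided view)\n", (long long) (nb01_view/nb00));

        // the kernels above index src0 as x[row*ns01 + ib];
        // indexing with nb instead of ns01 reads the wrong rows whenever ns01 != nb
        return 0;
    }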
@@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(

     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;
@@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
     kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }

-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const short NSG = FC_mul_mv_nsg;

     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK_K;

     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;

     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);

+    const int nb = args.ne00/QK_K;
+    const int ns01 = args.nb01/args.nb00;
+
     const short ix = tiisg/16; // 0 or 1
     const short it = tiisg%16; // 0...15
     const short ib = it/2;
@@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};

     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

@@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(

     float4 qf1, qf2;

-    for (int ibl = ix; ibl < nb; ibl += 2) {
+    // [TAG_MUL_MV_WEIRD]
+    for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];

-        for (short row = 0; row < nr0; ++row) {
-            device const block_iq4_xs & xb = x[row*nb + ibl];
+        for (short row = 0; row < NR0; ++row) {
+            device const block_iq4_xs & xb = x[row*ns01 + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

             float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(

     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;
@@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }

-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_mxfp4_f32_impl(
         args_t args,
         device const char * src0,
@@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
     const short NSG = FC_mul_mv_nsg;

     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK_MXFP4;

     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;

-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;

     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7731,27 +7736,32 @@ void kernel_mul_mv_mxfp4_f32_impl(
     device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);

+    const int nb = args.ne00/QK_MXFP4;
+    const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
+
     const short ix = tiisg/2; // 0...15
     const short it = tiisg%2; // 0 or 1

     shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);

     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};

-    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+    device const float * yb = y + ix*QK_MXFP4 + it*8;
+
+    // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
+    // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
+    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
+        device const float4 * y4 = (device const float4 *) yb;

-    for (int ib = ix; ib < nb; ib += 16) {
-        device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];

-#pragma unroll(nr0)
-        for (short row = 0; row < nr0; row++) {
-            device const block_mxfp4 & xb = x[row*nb + ib];
+        FOR_UNROLL (short row = 0; row < NR0; row++) {
+            device const block_mxfp4 & xb = x[row*ns01 + ib];
             device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);

             float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
@@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl(

     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;
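All three kernels keep the same work split after the nr0 -> NR0 rename: threadgroup index r0 and simdgroup index sgitg select a run of NR0 consecutive output rows starting at first_row = (r0*NSG + sgitg)*NR0, and the first_row + row < args.ne0 guard in the epilogue drops rows that fall past the end when ne0 is not a multiple of NSG*NR0. A host-side sketch of that row assignment (the NSG, NR0 and ne0 values are made up for the example, and the grid sizing shown is just one plausible way to cover all rows):

    #include <cstdio>

    int main() {
        const int NSG = 4;   // simdgroups per threadgroup (FC_mul_mv_nsg); assumed value
        const int NR0 = 2;   // rows accumulated per simdgroup (e.g. N_R0_IQ4_NL); assumed value
        const int ne0 = 11;  // output rows; deliberately not a multiple of NSG*NR0

        const int n_tg = (ne0 + NSG*NR0 - 1)/(NSG*NR0); // threadgroups along x

        for (int r0 = 0; r0 < n_tg; ++r0) {
            for (int sgitg = 0; sgitg < NSG; ++sgitg) {
                const int first_row = (r0*NSG + sgitg)*NR0;
                for (int row = 0; row < NR0 && first_row + row < ne0; ++row) {
                    // in the Metal kernels, lane 0 of this simdgroup writes
                    // dst_f32[first_row + row] = simd_sum(sumf[row])
                    printf("tg %d, simdgroup %d -> row %d\n", r0, sgitg, first_row + row);
                }
            }
        }
        return 0;
    }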

src/llama-model.cpp

Lines changed: 4 additions & 4 deletions
(The only change in this hunk is the indentation of the build_layer_ffn parameter list; leading whitespace is not preserved in this rendering, so the removed and added lines read identically.)

@@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
     }

     ggml_tensor * build_layer_ffn(
-        ggml_tensor * cur,
-        ggml_tensor * inpSA,
-        const llama_model & model,
-        const int il) {
+        ggml_tensor * cur,
+        ggml_tensor * inpSA,
+        const llama_model & model,
+        const int il) {

         // For Granite architectures - scale residual
         if (hparams.f_residual_scale) {
