Skip to content

Commit 35fb824

Browse files
authored
metal : dynamic simdgroups for MV kernels (ggml-org#16340)
* metal : dynamic simdgroups for MV kernels
* cont : minor
1 parent 3c62aed commit 35fb824

File tree

4 files changed

+119
-96
lines changed

4 files changed

+119
-96
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
495495
case GGML_TYPE_F16:
496496
case GGML_TYPE_BF16:
497497
{
498-
if (ne00 == 4) {
498+
if (ne00 < 32) {
499499
nsg = 1;
500500
nr0 = 32;
501-
nr1 = 4;
502-
suffix = "_c4";
503-
} else if (ne00 % 4 == 0) {
504-
nsg = N_SG_F;
505-
nr0 = N_R0_F;
506501
nr1 = 1;
507-
smem = 32*sizeof(float)*N_R0_F;
508-
suffix = "_4";
502+
suffix = "_short";
509503
} else {
510-
nsg = N_SG_F;
511-
nr0 = N_R0_F;
504+
nsg = std::min(4, (ne00 + 127) / 128);
505+
nr0 = 2;
512506
nr1 = 1;
513-
smem = 32*sizeof(float)*N_R0_F;
507+
smem = 32*sizeof(float)*nr0;
508+
suffix = ne00 % 4 == 0 ? "_4" : "";
514509
}
515510
} break;
516511
case GGML_TYPE_Q4_0:
@@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
727722
case GGML_TYPE_F16:
728723
case GGML_TYPE_BF16:
729724
{
730-
if (ne00 % 4 == 0) {
731-
nsg = N_SG_F;
732-
nr0 = N_R0_F;
733-
nr1 = 1;
734-
smem = 32*sizeof(float)*N_R0_F;
735-
suffix = "_4";
736-
} else {
737-
nsg = N_SG_F;
738-
nr0 = N_R0_F;
739-
nr1 = 1;
740-
smem = 32*sizeof(float)*N_R0_F;
741-
}
725+
nsg = std::min(4, (ne00 + 127) / 128);
726+
nr0 = 2;
727+
nr1 = 1;
728+
smem = 32*sizeof(float)*nr0;
729+
suffix = ne00 % 4 == 0 ? "_4" : "";
742730
} break;
743731
case GGML_TYPE_Q4_0:
744732
{

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11-
#define N_R0_F 2
12-
#define N_SG_F 4
13-
1411
#define N_R0_Q4_0 4
1512
#define N_SG_Q4_0 2
1613

@@ -352,6 +349,7 @@ typedef struct {
352349
uint64_t nb13;
353350
int32_t ne0;
354351
int32_t ne1;
352+
int32_t nr0;
355353
int16_t r2;
356354
int16_t r3;
357355
} ggml_metal_kargs_mul_mv;
@@ -427,6 +425,7 @@ typedef struct {
427425
int32_t ne0;
428426
int32_t ne1;
429427
uint64_t nb1;
428+
int32_t nr0;
430429
} ggml_metal_kargs_mul_mv_id;
431430

432431
// NORM

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15651565
} else {
15661566
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
15671567

1568+
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1569+
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1570+
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
1571+
1572+
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1573+
15681574
ggml_metal_kargs_mul_mv args = {
15691575
/*.ne00 =*/ ne00,
15701576
/*.ne01 =*/ ne01,
@@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15821588
/*.nb13 =*/ nb13,
15831589
/*.ne0 =*/ ne0,
15841590
/*.ne1 =*/ ne1,
1591+
/*.nr0 =*/ nr0,
15851592
/*.r2 =*/ r2,
15861593
/*.r3 =*/ r3,
15871594
};
15881595

1589-
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1590-
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1591-
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
1592-
1593-
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1594-
15951596
ggml_metal_encoder_set_pipeline(enc, pipeline);
15961597
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
15971598
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17581759
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
17591760
}
17601761
} else {
1762+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1763+
1764+
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1765+
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1766+
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
1767+
1768+
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1769+
17611770
ggml_metal_kargs_mul_mv_id args = {
17621771
/*.nei0 =*/ ne20,
17631772
/*.nei1 =*/ ne21,
@@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17781787
/*.ne0 =*/ ne0,
17791788
/*.ne1 =*/ ne1,
17801789
/*.nb1 =*/ nb1,
1790+
/*.nr0 =*/ nr0,
17811791
};
17821792

1783-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1784-
1785-
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1786-
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1787-
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
1788-
1789-
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1790-
17911793
if (ggml_is_quantized(op->src[0]->type)) {
17921794
GGML_ASSERT(ne00 >= nsg*nr0);
17931795
}

0 commit comments

Comments
 (0)