Skip to content

Commit b213fce

Browse files
authored
metal : improve F32, F16 and BF16 mat-vec multiplication (#16057)
* metal : improve F32, F16 and BF16 mat-vec multiplication ggml-ci * metal : make the NSG a function constant in mul_mv kernels ggml-ci
1 parent e00f3fd commit b213fce

File tree

6 files changed

+355
-288
lines changed

6 files changed

+355
-288
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434
}
3535

3636
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
37+
if (!ppls) {
38+
return;
39+
}
40+
3741
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
3842
ggml_metal_pipeline_free(it->second);
3943
}
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471
// use custom matrix x vector kernel
468472
switch (tsrc0) {
469473
case GGML_TYPE_F32:
474+
case GGML_TYPE_F16:
475+
case GGML_TYPE_BF16:
470476
{
471-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472-
473-
nsg = 1;
474-
nr0 = 1;
475-
nr1 = 4;
476477
if (ne00 == 4) {
478+
nsg = 1;
477479
nr0 = 32;
480+
nr1 = 4;
478481
suffix = "_c4";
479-
}
480-
} break;
481-
case GGML_TYPE_F16:
482-
case GGML_TYPE_BF16:
483-
{
484-
nsg = 1;
485-
nr0 = 1;
486-
if (op->src[1]->type == GGML_TYPE_F32) {
487-
if (ne00 == 4) {
488-
nr0 = 32;
489-
nr1 = 4;
490-
suffix = "_c4";
491-
} else if (ne11 * ne12 < 4) {
492-
suffix = "_1row";
493-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
494-
suffix = "_l4";
495-
nr1 = ne11;
496-
} else {
497-
nr1 = 4;
498-
}
482+
} else if (ne00 % 4 == 0) {
483+
nsg = N_SG_F;
484+
nr0 = N_R0_F;
485+
nr1 = 1;
486+
smem = 32*sizeof(float)*N_R0_F;
487+
suffix = "_4";
499488
} else {
500-
nr1 = 4;
489+
nsg = N_SG_F;
490+
nr0 = N_R0_F;
491+
nr1 = 1;
492+
smem = 32*sizeof(float)*N_R0_F;
501493
}
502494
} break;
503495
case GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
623615
return res;
624616
}
625617

626-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
618+
ggml_metal_cv_t cv = ggml_metal_cv_init();
619+
620+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
621+
622+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
623+
624+
ggml_metal_cv_free(cv);
627625

628626
ggml_metal_pipeline_set_nr0 (res, nr0);
629627
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689687
const ggml_type tsrc0 = op->src[0]->type;
690688
const ggml_type tsrc1 = op->src[1]->type;
691689

690+
const char * suffix = "";
691+
692692
// use custom matrix x vector kernel
693693
switch (tsrc0) {
694694
case GGML_TYPE_F32:
695-
{
696-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
697-
nsg = 1;
698-
nr0 = 1;
699-
} break;
700695
case GGML_TYPE_F16:
701-
{
702-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
703-
nsg = 1;
704-
nr0 = 1;
705-
} break;
706696
case GGML_TYPE_BF16:
707697
{
708-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
709-
nsg = 1;
710-
nr0 = 1;
698+
if (ne00 % 4 == 0) {
699+
nsg = N_SG_F;
700+
nr0 = N_R0_F;
701+
nr1 = 1;
702+
smem = 32*sizeof(float)*N_R0_F;
703+
suffix = "_4";
704+
} else {
705+
nsg = N_SG_F;
706+
nr0 = N_R0_F;
707+
nr1 = 1;
708+
smem = 32*sizeof(float)*N_R0_F;
709+
}
711710
} break;
712711
case GGML_TYPE_Q4_0:
713712
{
@@ -824,15 +823,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824823
}
825824
};
826825

827-
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
826+
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
828827
snprintf(name, 256, "%s", base);
829828

830829
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
831830
if (res) {
832831
return res;
833832
}
834833

835-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
834+
ggml_metal_cv_t cv = ggml_metal_cv_init();
835+
836+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
837+
838+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
839+
840+
ggml_metal_cv_free(cv);
836841

837842
ggml_metal_pipeline_set_nr0 (res, nr0);
838843
ggml_metal_pipeline_set_nr1 (res, nr1);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
2222
ggml_metal_cv_t ggml_metal_cv_init(void);
2323
void ggml_metal_cv_free(ggml_metal_cv_t cv);
2424

25+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
2526
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
2627
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
2728

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
5151
free(cv);
5252
}
5353

54+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
55+
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
56+
}
57+
5458
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
5559
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
5660
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
#define N_R0_F 2
12+
#define N_SG_F 4
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

@@ -72,6 +75,7 @@
7275
#define FC_FLASH_ATTN_EXT 100
7376
#define FC_FLASH_ATTN_EXT_VEC 200
7477
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
78+
#define FC_MUL_MV 400
7579

7680
// kernel argument structs
7781
//

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15641564

15651565
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
15661566

1567-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1567+
if (op->src[0]->type == GGML_TYPE_F32 ||
1568+
op->src[0]->type == GGML_TYPE_F16 ||
1569+
op->src[0]->type == GGML_TYPE_BF16 ||
1570+
op->src[0]->type == GGML_TYPE_Q8_0) {
15681571
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
15691572
} else {
15701573
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17721775

17731776
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
17741777

1775-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1778+
if (op->src[0]->type == GGML_TYPE_F32 ||
1779+
op->src[0]->type == GGML_TYPE_F16 ||
1780+
op->src[0]->type == GGML_TYPE_BF16 ||
1781+
op->src[0]->type == GGML_TYPE_Q8_0) {
17761782
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
17771783
} else {
17781784
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);

0 commit comments

Comments
 (0)