Skip to content

Commit 320f029

Browse files
committed
metal : improve F32, F16 and BF16 mat-vec multiplication
ggml-ci
1 parent e00f3fd commit 320f029

File tree

4 files changed

+226
-211
lines changed

4 files changed

+226
-211
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434
}
3535

3636
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
37+
if (!ppls) {
38+
return;
39+
}
40+
3741
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
3842
ggml_metal_pipeline_free(it->second);
3943
}
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471
// use custom matrix x vector kernel
468472
switch (tsrc0) {
469473
case GGML_TYPE_F32:
474+
case GGML_TYPE_F16:
475+
case GGML_TYPE_BF16:
470476
{
471-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472-
473-
nsg = 1;
474-
nr0 = 1;
475-
nr1 = 4;
476477
if (ne00 == 4) {
478+
nsg = 1;
477479
nr0 = 32;
480+
nr1 = 4;
478481
suffix = "_c4";
479-
}
480-
} break;
481-
case GGML_TYPE_F16:
482-
case GGML_TYPE_BF16:
483-
{
484-
nsg = 1;
485-
nr0 = 1;
486-
if (op->src[1]->type == GGML_TYPE_F32) {
487-
if (ne00 == 4) {
488-
nr0 = 32;
489-
nr1 = 4;
490-
suffix = "_c4";
491-
} else if (ne11 * ne12 < 4) {
492-
suffix = "_1row";
493-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
494-
suffix = "_l4";
495-
nr1 = ne11;
496-
} else {
497-
nr1 = 4;
498-
}
482+
} else if (ne00 % 4 == 0) {
483+
nsg = N_SG_F;
484+
nr0 = N_R0_F;
485+
nr1 = 1;
486+
smem = 32*sizeof(float)*N_R0_F;
487+
suffix = "_4";
499488
} else {
500-
nr1 = 4;
489+
nsg = N_SG_F;
490+
nr0 = N_R0_F;
491+
nr1 = 1;
492+
smem = 32*sizeof(float)*N_R0_F;
501493
}
502494
} break;
503495
case GGML_TYPE_Q4_0:
@@ -689,25 +681,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689681
const ggml_type tsrc0 = op->src[0]->type;
690682
const ggml_type tsrc1 = op->src[1]->type;
691683

684+
const char * suffix = "";
685+
692686
// use custom matrix x vector kernel
693687
switch (tsrc0) {
694688
case GGML_TYPE_F32:
695-
{
696-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
697-
nsg = 1;
698-
nr0 = 1;
699-
} break;
700689
case GGML_TYPE_F16:
701-
{
702-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
703-
nsg = 1;
704-
nr0 = 1;
705-
} break;
706690
case GGML_TYPE_BF16:
707691
{
708-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
709-
nsg = 1;
710-
nr0 = 1;
692+
if (ne00 % 4 == 0) {
693+
nsg = N_SG_F;
694+
nr0 = N_R0_F;
695+
nr1 = 1;
696+
smem = 32*sizeof(float)*N_R0_F;
697+
suffix = "_4";
698+
} else {
699+
nsg = N_SG_F;
700+
nr0 = N_R0_F;
701+
nr1 = 1;
702+
smem = 32*sizeof(float)*N_R0_F;
703+
}
711704
} break;
712705
case GGML_TYPE_Q4_0:
713706
{
@@ -824,7 +817,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824817
}
825818
};
826819

827-
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
820+
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
828821
snprintf(name, 256, "%s", base);
829822

830823
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
#define N_R0_F 2
12+
#define N_SG_F 4
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15641564

15651565
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
15661566

1567-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1567+
if (op->src[0]->type == GGML_TYPE_F32 ||
1568+
op->src[0]->type == GGML_TYPE_F16 ||
1569+
op->src[0]->type == GGML_TYPE_BF16 ||
1570+
op->src[0]->type == GGML_TYPE_Q8_0) {
15681571
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
15691572
} else {
15701573
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17721775

17731776
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
17741777

1775-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1778+
if (op->src[0]->type == GGML_TYPE_F32 ||
1779+
op->src[0]->type == GGML_TYPE_F16 ||
1780+
op->src[0]->type == GGML_TYPE_BF16 ||
1781+
op->src[0]->type == GGML_TYPE_Q8_0) {
17761782
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
17771783
} else {
17781784
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);

0 commit comments

Comments
 (0)