Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 47 additions & 42 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
}

void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
if (!ppls) {
return;
}

for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
ggml_metal_pipeline_free(it->second);
}
Expand Down Expand Up @@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
// use custom matrix x vector kernel
switch (tsrc0) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

nsg = 1;
nr0 = 1;
nr1 = 4;
if (ne00 == 4) {
nsg = 1;
nr0 = 32;
nr1 = 4;
suffix = "_c4";
}
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
nsg = 1;
nr0 = 1;
if (op->src[1]->type == GGML_TYPE_F32) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
suffix = "_c4";
} else if (ne11 * ne12 < 4) {
suffix = "_1row";
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
suffix = "_l4";
nr1 = ne11;
} else {
nr1 = 4;
}
} else if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nr1 = 4;
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
} break;
case GGML_TYPE_Q4_0:
Expand Down Expand Up @@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

ggml_metal_pipeline_set_nr0 (res, nr0);
ggml_metal_pipeline_set_nr1 (res, nr1);
Expand Down Expand Up @@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;

const char * suffix = "";

// use custom matrix x vector kernel
switch (tsrc0) {
case GGML_TYPE_F32:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
} break;
case GGML_TYPE_BF16:
{
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
nsg = 1;
nr0 = 1;
if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
} break;
case GGML_TYPE_Q4_0:
{
Expand Down Expand Up @@ -824,15 +823,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
}
};

snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

ggml_metal_pipeline_set_nr0 (res, nr0);
ggml_metal_pipeline_set_nr1 (res, nr1);
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
ggml_metal_cv_t ggml_metal_cv_init(void);
void ggml_metal_cv_free(ggml_metal_cv_t cv);

void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
free(cv);
}

void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
}

void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
}
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
//
// TODO: for optimal performance, become function of the device and work size

#define N_R0_F 2
#define N_SG_F 4

#define N_R0_Q4_0 4
#define N_SG_Q4_0 2

Expand Down Expand Up @@ -72,6 +75,7 @@
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400

// kernel argument structs
//
Expand Down
10 changes: 8 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

if (op->src[0]->type == GGML_TYPE_Q8_0) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q8_0) {
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
} else {
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
Expand Down Expand Up @@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

if (op->src[0]->type == GGML_TYPE_Q8_0) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q8_0) {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
} else {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
Expand Down
Loading
Loading