Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
return res;
}

// Returns the matrix-matrix multiplication pipeline for `op`, compiling it if
// the library does not already hold one under the same name.
//
// The kernel is chosen from the types of op->src[0] and op->src[1]. Two boolean
// function constants specialize it (presumably "bc" = bounds check - confirm
// against the kernel source):
//   FC_MUL_MM + 0  bc_inp: src0's ne[0] is not a multiple of 32
//   FC_MUL_MM + 1  bc_out: op's ne[0] is not a multiple of 64, or its ne[1] is
//                          not a multiple of 32
// Both flags are encoded in the pipeline name, so the lookup by name
// distinguishes the specialized variants.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
    char base[256];
    char name[256];

    const ggml_type tsrc0 = op->src[0]->type;
    const ggml_type tsrc1 = op->src[1]->type;

    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
    const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;

    // `base` is the kernel function name; `name` additionally carries the constant values
    snprintf(base, sizeof(base), "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
    snprintf(name, sizeof(name), "%s_bci=%d_bco=%d", base, bc_inp, bc_out);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    // compile exactly once, with the function constants attached
    ggml_metal_cv_t cv = ggml_metal_cv_init();

    ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
    ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);

    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

    ggml_metal_cv_free(cv);

    // when the output size is not a multiple of 64x32, we need extra smem to prevent out-of-bounds writes
    ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);

    return res;
}
Expand Down Expand Up @@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];

const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;

const bool bc_inp = op->src[0]->ne[0] % 32 != 0;

snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_bci=%d", base, bc_inp);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

ggml_metal_pipeline_set_smem(res, 8192);

Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
// mul_mm and mul_mm_id take the op tensor; the definitions read the source types from op->src[0] and op->src[1]
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
3 changes: 1 addition & 2 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction &&
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
return has_simdgroup_reduction;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500

// kernel argument structs
//
Expand Down
40 changes: 16 additions & 24 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,22 +1476,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now, the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPUs and older A-series chips reuse the matrix-vector multiplication kernel
props_dev->has_simdgroup_mm &&
op->src[1]->type == GGML_TYPE_F32 &&
ne00 % 32 == 0 && ne00 >= 64 &&
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);

ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
Expand Down Expand Up @@ -1612,8 +1610,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
GGML_ASSERT(!ggml_is_transposed(op->src[1]));

GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);

Expand All @@ -1631,19 +1627,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;

if (props_dev->has_simdgroup_mm &&
ne00 % 32 == 0 && ne00 >= 64 &&
(ne21 >= ne21_mm_id_min)) {
GGML_ASSERT(ne00 % 4 == 0);

if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}

// extra buffers for intermediate id mapping
ggml_metal_buffer_id bid_tpe = bid_dst;
Expand Down Expand Up @@ -1687,7 +1679,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);

{
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);

ggml_metal_kargs_mul_mm_id args = {
/*.ne00 =*/ ne00,
Expand Down
Loading
Loading