Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
if (ne00 == 4) {
if (ne00 < 32) {
nsg = 1;
nr0 = 32;
nr1 = 4;
suffix = "_c4";
} else if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
suffix = "_short";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nsg = std::min(4, (ne00 + 127) / 128);
nr0 = 2;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
}
} break;
case GGML_TYPE_Q4_0:
Expand Down Expand Up @@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
nsg = std::min(4, (ne00 + 127) / 128);
nr0 = 2;
nr1 = 1;
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
} break;
case GGML_TYPE_Q4_0:
{
Expand Down
5 changes: 2 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
//
// TODO: for optimal performance, become function of the device and work size

#define N_R0_F 2
#define N_SG_F 4

#define N_R0_Q4_0 4
#define N_SG_Q4_0 2

Expand Down Expand Up @@ -352,6 +349,7 @@ typedef struct {
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t nr0;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mv;
Expand Down Expand Up @@ -427,6 +425,7 @@ typedef struct {
int32_t ne0;
int32_t ne1;
uint64_t nb1;
int32_t nr0;
} ggml_metal_kargs_mul_mv_id;

// NORM
Expand Down
30 changes: 16 additions & 14 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);

const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);

const size_t smem = ggml_metal_pipeline_get_smem(pipeline);

ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
Expand All @@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nr0 =*/ nr0,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};

const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);

const size_t smem = ggml_metal_pipeline_get_smem(pipeline);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
Expand Down Expand Up @@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
}
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);

const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);

const size_t smem = ggml_metal_pipeline_get_smem(pipeline);

ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
Expand All @@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb1 =*/ nb1,
/*.nr0 =*/ nr0,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);

const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);

const size_t smem = ggml_metal_pipeline_get_smem(pipeline);

if (ggml_is_quantized(op->src[0]->type)) {
GGML_ASSERT(ne00 >= nsg*nr0);
}
Expand Down
Loading