Skip to content

Commit 40091b2

Browse files
authored
Merge branch 'ggml-org:master' into mradermacher
2 parents 6834e43 + 2b6b55a commit 40091b2

File tree

10 files changed

+406
-319
lines changed

10 files changed

+406
-319
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1728,7 +1728,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
17281728
ggml_cann_get_rows(ctx, dst);
17291729
break;
17301730
case GGML_OP_SET_ROWS:
1731-
std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl;
17321731
ggml_cann_set_rows(ctx, dst);
17331732
break;
17341733
case GGML_OP_DUP:

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434
}
3535

3636
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
37+
if (!ppls) {
38+
return;
39+
}
40+
3741
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
3842
ggml_metal_pipeline_free(it->second);
3943
}
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471
// use custom matrix x vector kernel
468472
switch (tsrc0) {
469473
case GGML_TYPE_F32:
474+
case GGML_TYPE_F16:
475+
case GGML_TYPE_BF16:
470476
{
471-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472-
473-
nsg = 1;
474-
nr0 = 1;
475-
nr1 = 4;
476477
if (ne00 == 4) {
478+
nsg = 1;
477479
nr0 = 32;
480+
nr1 = 4;
478481
suffix = "_c4";
479-
}
480-
} break;
481-
case GGML_TYPE_F16:
482-
case GGML_TYPE_BF16:
483-
{
484-
nsg = 1;
485-
nr0 = 1;
486-
if (op->src[1]->type == GGML_TYPE_F32) {
487-
if (ne00 == 4) {
488-
nr0 = 32;
489-
nr1 = 4;
490-
suffix = "_c4";
491-
} else if (ne11 * ne12 < 4) {
492-
suffix = "_1row";
493-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
494-
suffix = "_l4";
495-
nr1 = ne11;
496-
} else {
497-
nr1 = 4;
498-
}
482+
} else if (ne00 % 4 == 0) {
483+
nsg = N_SG_F;
484+
nr0 = N_R0_F;
485+
nr1 = 1;
486+
smem = 32*sizeof(float)*N_R0_F;
487+
suffix = "_4";
499488
} else {
500-
nr1 = 4;
489+
nsg = N_SG_F;
490+
nr0 = N_R0_F;
491+
nr1 = 1;
492+
smem = 32*sizeof(float)*N_R0_F;
501493
}
502494
} break;
503495
case GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
623615
return res;
624616
}
625617

626-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
618+
ggml_metal_cv_t cv = ggml_metal_cv_init();
619+
620+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
621+
622+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
623+
624+
ggml_metal_cv_free(cv);
627625

628626
ggml_metal_pipeline_set_nr0 (res, nr0);
629627
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689687
const ggml_type tsrc0 = op->src[0]->type;
690688
const ggml_type tsrc1 = op->src[1]->type;
691689

690+
const char * suffix = "";
691+
692692
// use custom matrix x vector kernel
693693
switch (tsrc0) {
694694
case GGML_TYPE_F32:
695-
{
696-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
697-
nsg = 1;
698-
nr0 = 1;
699-
} break;
700695
case GGML_TYPE_F16:
701-
{
702-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
703-
nsg = 1;
704-
nr0 = 1;
705-
} break;
706696
case GGML_TYPE_BF16:
707697
{
708-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
709-
nsg = 1;
710-
nr0 = 1;
698+
if (ne00 % 4 == 0) {
699+
nsg = N_SG_F;
700+
nr0 = N_R0_F;
701+
nr1 = 1;
702+
smem = 32*sizeof(float)*N_R0_F;
703+
suffix = "_4";
704+
} else {
705+
nsg = N_SG_F;
706+
nr0 = N_R0_F;
707+
nr1 = 1;
708+
smem = 32*sizeof(float)*N_R0_F;
709+
}
711710
} break;
712711
case GGML_TYPE_Q4_0:
713712
{
@@ -824,15 +823,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824823
}
825824
};
826825

827-
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
826+
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
828827
snprintf(name, 256, "%s", base);
829828

830829
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
831830
if (res) {
832831
return res;
833832
}
834833

835-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
834+
ggml_metal_cv_t cv = ggml_metal_cv_init();
835+
836+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
837+
838+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
839+
840+
ggml_metal_cv_free(cv);
836841

837842
ggml_metal_pipeline_set_nr0 (res, nr0);
838843
ggml_metal_pipeline_set_nr1 (res, nr1);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
2222
ggml_metal_cv_t ggml_metal_cv_init(void);
2323
void ggml_metal_cv_free(ggml_metal_cv_t cv);
2424

25+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
2526
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
2627
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
2728

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
5151
free(cv);
5252
}
5353

54+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
55+
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
56+
}
57+
5458
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
5559
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
5660
}
@@ -327,12 +331,19 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
327331

328332
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
329333

330-
id<MTLFunction> mtl_function = [lib->obj newFunctionWithName:base_func constantValues:(cv ? cv->obj : nil) error:&error];
334+
id<MTLFunction> mtl_function;
335+
if (!cv) {
336+
mtl_function = [lib->obj newFunctionWithName:base_func];
337+
} else {
338+
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
339+
}
331340
if (!mtl_function) {
332341
ggml_critical_section_end();
333342

334343
GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
335-
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
344+
if (error) {
345+
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
346+
}
336347

337348
return nil;
338349
}
@@ -817,6 +828,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
817828

818829
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
819830
bool is_shared;
831+
bool owned;
820832

821833
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
822834
int n_buffers;
@@ -949,6 +961,7 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
949961
if (shared) {
950962
res->all_data = ggml_metal_host_malloc(size_aligned);
951963
res->is_shared = true;
964+
res->owned = true;
952965
} else {
953966
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
954967
res->all_data = (void *) 0x000000400ULL;
@@ -1007,6 +1020,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
10071020
res->all_size = size;
10081021

10091022
res->is_shared = true;
1023+
res->owned = false;
10101024

10111025
res->n_buffers = 0;
10121026

@@ -1100,7 +1114,7 @@ void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
11001114

11011115
ggml_metal_buffer_rset_free(buf);
11021116

1103-
if (buf->is_shared) {
1117+
if (buf->is_shared && buf->owned) {
11041118
#if TARGET_OS_OSX
11051119
vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size);
11061120
#else

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
#define N_R0_F 2
12+
#define N_SG_F 4
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

@@ -72,6 +75,7 @@
7275
#define FC_FLASH_ATTN_EXT 100
7376
#define FC_FLASH_ATTN_EXT_VEC 200
7477
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
78+
#define FC_MUL_MV 400
7579

7680
// kernel argument structs
7781
//

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15641564

15651565
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
15661566

1567-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1567+
if (op->src[0]->type == GGML_TYPE_F32 ||
1568+
op->src[0]->type == GGML_TYPE_F16 ||
1569+
op->src[0]->type == GGML_TYPE_BF16 ||
1570+
op->src[0]->type == GGML_TYPE_Q8_0) {
15681571
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
15691572
} else {
15701573
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17721775

17731776
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
17741777

1775-
if (op->src[0]->type == GGML_TYPE_Q8_0) {
1778+
if (op->src[0]->type == GGML_TYPE_F32 ||
1779+
op->src[0]->type == GGML_TYPE_F16 ||
1780+
op->src[0]->type == GGML_TYPE_BF16 ||
1781+
op->src[0]->type == GGML_TYPE_Q8_0) {
17761782
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
17771783
} else {
17781784
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);

0 commit comments

Comments
 (0)