Skip to content

Commit 7760ede

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents 215df9c + 368560a commit 7760ede

File tree

12 files changed

+744
-598
lines changed

12 files changed

+744
-598
lines changed

common/arg.cpp

Lines changed: 266 additions & 202 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/cpy.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
441441
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
442442
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
443443
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
444+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
445+
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
446+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
447+
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
444448
} else {
445449
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
446450
ggml_type_name(src0->type), ggml_type_name(src1->type));

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
3535
switch (D) {
3636
case 64:
3737
case 128:
38-
return 128;
3938
case 256:
4039
return ncols <= 16 ? 128 : 64;
4140
default:
@@ -86,7 +85,6 @@ static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols
8685
switch (D) {
8786
case 64:
8887
case 128:
89-
return 128;
9088
case 256:
9189
return ncols <= 16 ? 128 : 64;
9290
default:

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434
}
3535

3636
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
37+
if (!ppls) {
38+
return;
39+
}
40+
3741
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
3842
ggml_metal_pipeline_free(it->second);
3943
}
@@ -410,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
410414
return res;
411415
}
412416

413-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
417+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
414418
char base[256];
415419
char name[256];
416420

417421
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
418-
snprintf(name, 256, "%s", base);
422+
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
419423

420424
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
421425
if (res) {
422426
return res;
423427
}
424428

425-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
429+
ggml_metal_cv_t cv = ggml_metal_cv_init();
430+
431+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
432+
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
433+
434+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
435+
436+
ggml_metal_cv_free(cv);
426437

427438
return res;
428439
}
@@ -467,37 +478,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467478
// use custom matrix x vector kernel
468479
switch (tsrc0) {
469480
case GGML_TYPE_F32:
481+
case GGML_TYPE_F16:
482+
case GGML_TYPE_BF16:
470483
{
471-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472-
473-
nsg = 1;
474-
nr0 = 1;
475-
nr1 = 4;
476484
if (ne00 == 4) {
485+
nsg = 1;
477486
nr0 = 32;
487+
nr1 = 4;
478488
suffix = "_c4";
479-
}
480-
} break;
481-
case GGML_TYPE_F16:
482-
case GGML_TYPE_BF16:
483-
{
484-
nsg = 1;
485-
nr0 = 1;
486-
if (op->src[1]->type == GGML_TYPE_F32) {
487-
if (ne00 == 4) {
488-
nr0 = 32;
489-
nr1 = 4;
490-
suffix = "_c4";
491-
} else if (ne11 * ne12 < 4) {
492-
suffix = "_1row";
493-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
494-
suffix = "_l4";
495-
nr1 = ne11;
496-
} else {
497-
nr1 = 4;
498-
}
489+
} else if (ne00 % 4 == 0) {
490+
nsg = N_SG_F;
491+
nr0 = N_R0_F;
492+
nr1 = 1;
493+
smem = 32*sizeof(float)*N_R0_F;
494+
suffix = "_4";
499495
} else {
500-
nr1 = 4;
496+
nsg = N_SG_F;
497+
nr0 = N_R0_F;
498+
nr1 = 1;
499+
smem = 32*sizeof(float)*N_R0_F;
501500
}
502501
} break;
503502
case GGML_TYPE_Q4_0:
@@ -616,14 +615,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
616615
};
617616

618617
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
619-
snprintf(name, 256, "%s", base);
618+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
620619

621620
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
622621
if (res) {
623622
return res;
624623
}
625624

626-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
625+
ggml_metal_cv_t cv = ggml_metal_cv_init();
626+
627+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
628+
629+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
630+
631+
ggml_metal_cv_free(cv);
627632

628633
ggml_metal_pipeline_set_nr0 (res, nr0);
629634
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +694,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689694
const ggml_type tsrc0 = op->src[0]->type;
690695
const ggml_type tsrc1 = op->src[1]->type;
691696

697+
const char * suffix = "";
698+
692699
// use custom matrix x vector kernel
693700
switch (tsrc0) {
694701
case GGML_TYPE_F32:
695-
{
696-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
697-
nsg = 1;
698-
nr0 = 1;
699-
} break;
700702
case GGML_TYPE_F16:
701-
{
702-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
703-
nsg = 1;
704-
nr0 = 1;
705-
} break;
706703
case GGML_TYPE_BF16:
707704
{
708-
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
709-
nsg = 1;
710-
nr0 = 1;
705+
if (ne00 % 4 == 0) {
706+
nsg = N_SG_F;
707+
nr0 = N_R0_F;
708+
nr1 = 1;
709+
smem = 32*sizeof(float)*N_R0_F;
710+
suffix = "_4";
711+
} else {
712+
nsg = N_SG_F;
713+
nr0 = N_R0_F;
714+
nr1 = 1;
715+
smem = 32*sizeof(float)*N_R0_F;
716+
}
711717
} break;
712718
case GGML_TYPE_Q4_0:
713719
{
@@ -824,15 +830,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824830
}
825831
};
826832

827-
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
828-
snprintf(name, 256, "%s", base);
833+
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
834+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
829835

830836
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
831837
if (res) {
832838
return res;
833839
}
834840

835-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
841+
ggml_metal_cv_t cv = ggml_metal_cv_init();
842+
843+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
844+
845+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
846+
847+
ggml_metal_cv_free(cv);
836848

837849
ggml_metal_pipeline_set_nr0 (res, nr0);
838850
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -918,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
918930
dk,
919931
dv);
920932

921-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
922-
"flash_attn_ext",
923-
ggml_type_name(op->src[1]->type),
924-
dk,
925-
dv,
933+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
934+
base,
926935
has_mask,
927936
has_sinks,
928937
has_bias,
@@ -980,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
980989
dk,
981990
dv);
982991

983-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
984-
"flash_attn_ext_vec",
985-
ggml_type_name(op->src[1]->type),
986-
dk,
987-
dv,
992+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
993+
base,
988994
has_mask,
989995
has_sinks,
990996
has_bias,
@@ -1028,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10281034
char name[256];
10291035

10301036
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1031-
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
1037+
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
10321038

10331039
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
10341040
if (res) {

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
2222
ggml_metal_cv_t ggml_metal_cv_init(void);
2323
void ggml_metal_cv_free(ggml_metal_cv_t cv);
2424

25+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
2526
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
2627
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
2728

@@ -113,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
113114
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
114115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
115116
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
116-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
117+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
117118
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
118119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
119120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
5151
free(cv);
5252
}
5353

54+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
55+
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
56+
}
57+
5458
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
5559
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
5660
}
@@ -327,12 +331,19 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
327331

328332
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
329333

330-
id<MTLFunction> mtl_function = [lib->obj newFunctionWithName:base_func constantValues:(cv ? cv->obj : nil) error:&error];
334+
id<MTLFunction> mtl_function;
335+
if (!cv) {
336+
mtl_function = [lib->obj newFunctionWithName:base_func];
337+
} else {
338+
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
339+
}
331340
if (!mtl_function) {
332341
ggml_critical_section_end();
333342

334343
GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
335-
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
344+
if (error) {
345+
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
346+
}
336347

337348
return nil;
338349
}
@@ -817,6 +828,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
817828

818829
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
819830
bool is_shared;
831+
bool owned;
820832

821833
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
822834
int n_buffers;
@@ -949,6 +961,7 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
949961
if (shared) {
950962
res->all_data = ggml_metal_host_malloc(size_aligned);
951963
res->is_shared = true;
964+
res->owned = true;
952965
} else {
953966
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
954967
res->all_data = (void *) 0x000000400ULL;
@@ -1007,6 +1020,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
10071020
res->all_size = size;
10081021

10091022
res->is_shared = true;
1023+
res->owned = false;
10101024

10111025
res->n_buffers = 0;
10121026

@@ -1100,7 +1114,7 @@ void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
11001114

11011115
ggml_metal_buffer_rset_free(buf);
11021116

1103-
if (buf->is_shared) {
1117+
if (buf->is_shared && buf->owned) {
11041118
#if TARGET_OS_OSX
11051119
vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size);
11061120
#else

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
#define N_R0_F 2
12+
#define N_SG_F 4
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

@@ -32,13 +35,13 @@
3235
#define N_R0_Q3_K 2
3336
#define N_SG_Q3_K 2
3437

35-
#define N_R0_Q4_K 4
38+
#define N_R0_Q4_K 2
3639
#define N_SG_Q4_K 2
3740

3841
#define N_R0_Q5_K 2
3942
#define N_SG_Q5_K 2
4043

41-
#define N_R0_Q6_K 1
44+
#define N_R0_Q6_K 2
4245
#define N_SG_Q6_K 2
4346

4447
#define N_R0_IQ1_S 4
@@ -72,6 +75,7 @@
7275
#define FC_FLASH_ATTN_EXT 100
7376
#define FC_FLASH_ATTN_EXT_VEC 200
7477
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
78+
#define FC_MUL_MV 400
7579

7680
// kernel argument structs
7781
//
@@ -370,9 +374,6 @@ typedef struct {
370374
int32_t ne1;
371375
int16_t r2;
372376
int16_t r3;
373-
int16_t nsg;
374-
int16_t nxpsg;
375-
int16_t r1ptg;
376377
} ggml_metal_kargs_mul_mv_ext;
377378

378379
typedef struct {

0 commit comments

Comments
 (0)