Skip to content

Commit 703f9e3

Browse files
authored
metal : use function constants for mul_mv_ext kernels (#16074)
* metal : use function constants for mul_mv_ext kernels (ggml-ci)
* metal : remove NW template argument (ggml-ci)
* metal : adjust constants (ggml-ci)
1 parent ad6bd90 commit 703f9e3

File tree

5 files changed, with 158 additions (+158) and 166 deletions (-166).

5 files changed

+158
-166
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
414414
return res;
415415
}
416416

417-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
417+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
418418
char base[256];
419419
char name[256];
420420

421421
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
422-
snprintf(name, 256, "%s", base);
422+
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
423423

424424
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
425425
if (res) {
426426
return res;
427427
}
428428

429-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
429+
ggml_metal_cv_t cv = ggml_metal_cv_init();
430+
431+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
432+
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
433+
434+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
435+
436+
ggml_metal_cv_free(cv);
430437

431438
return res;
432439
}
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
608615
};
609616

610617
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
611-
snprintf(name, 256, "%s", base);
618+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
612619

613620
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
614621
if (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824831
};
825832

826833
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
827-
snprintf(name, 256, "%s", base);
834+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
828835

829836
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
830837
if (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
923930
dk,
924931
dv);
925932

926-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
927-
"flash_attn_ext",
928-
ggml_type_name(op->src[1]->type),
929-
dk,
930-
dv,
933+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
934+
base,
931935
has_mask,
932936
has_sinks,
933937
has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
985989
dk,
986990
dv);
987991

988-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
989-
"flash_attn_ext_vec",
990-
ggml_type_name(op->src[1]->type),
991-
dk,
992-
dv,
992+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
993+
base,
993994
has_mask,
994995
has_sinks,
995996
has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10331034
char name[256];
10341035

10351036
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1036-
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
1037+
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
10371038

10381039
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
10391040
if (res) {

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
114114
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
115115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
116116
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
117-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
117+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
118118
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
119119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@
3535
#define N_R0_Q3_K 2
3636
#define N_SG_Q3_K 2
3737

38-
#define N_R0_Q4_K 4
38+
#define N_R0_Q4_K 2
3939
#define N_SG_Q4_K 2
4040

4141
#define N_R0_Q5_K 2
4242
#define N_SG_Q5_K 2
4343

44-
#define N_R0_Q6_K 1
44+
#define N_R0_Q6_K 2
4545
#define N_SG_Q6_K 2
4646

4747
#define N_R0_IQ1_S 4
@@ -374,9 +374,6 @@ typedef struct {
374374
int32_t ne1;
375375
int16_t r2;
376376
int16_t r3;
377-
int16_t nsg;
378-
int16_t nxpsg;
379-
int16_t r1ptg;
380377
} ggml_metal_kargs_mul_mv_ext;
381378

382379
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14441444
GGML_ABORT("unsupported ne11");
14451445
};
14461446

1447-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
1447+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
14481448

14491449
ggml_metal_kargs_mul_mv_ext args = {
14501450
/*.ne00 =*/ ne00,
@@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14651465
/*.ne1 =*/ ne1,
14661466
/*.r2 =*/ r2,
14671467
/*.r3 =*/ r3,
1468-
/*.nsg =*/ nsg,
1469-
/*.nxpsg =*/ nxpsg,
1470-
/*.r1ptg =*/ r1ptg,
14711468
};
14721469

14731470
ggml_metal_encoder_set_pipeline(enc, pipeline);

0 commit comments

Comments
 (0)