Skip to content

Commit 72c5703

Browse files
committed
metal : use function constants for mul_mv_ext kernels
ggml-ci
1 parent b213fce commit 72c5703

File tree

5 files changed

+33
-41
lines changed

5 files changed

+33
-41
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
414414
return res;
415415
}
416416

417-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
417+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
418418
char base[256];
419419
char name[256];
420420

421421
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
422-
snprintf(name, 256, "%s", base);
422+
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
423423

424424
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
425425
if (res) {
426426
return res;
427427
}
428428

429-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
429+
ggml_metal_cv_t cv = ggml_metal_cv_init();
430+
431+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
432+
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
433+
434+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
435+
436+
ggml_metal_cv_free(cv);
430437

431438
return res;
432439
}
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
608615
};
609616

610617
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
611-
snprintf(name, 256, "%s", base);
618+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
612619

613620
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
614621
if (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824831
};
825832

826833
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
827-
snprintf(name, 256, "%s", base);
834+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
828835

829836
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
830837
if (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
923930
dk,
924931
dv);
925932

926-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
927-
"flash_attn_ext",
928-
ggml_type_name(op->src[1]->type),
929-
dk,
930-
dv,
933+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
934+
base,
931935
has_mask,
932936
has_sinks,
933937
has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
985989
dk,
986990
dv);
987991

988-
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
989-
"flash_attn_ext_vec",
990-
ggml_type_name(op->src[1]->type),
991-
dk,
992-
dv,
992+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
993+
base,
993994
has_mask,
994995
has_sinks,
995996
has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10331034
char name[256];
10341035

10351036
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1036-
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
1037+
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
10371038

10381039
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
10391040
if (res) {

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
114114
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
115115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
116116
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
117-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
117+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
118118
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
119119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,6 @@ typedef struct {
374374
int32_t ne1;
375375
int16_t r2;
376376
int16_t r3;
377-
int16_t nsg;
378-
int16_t nxpsg;
379-
int16_t r1ptg;
380377
} ggml_metal_kargs_mul_mv_ext;
381378

382379
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14441444
GGML_ABORT("unsupported ne11");
14451445
};
14461446

1447-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
1447+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
14481448

14491449
ggml_metal_kargs_mul_mv_ext args = {
14501450
/*.ne00 =*/ ne00,
@@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14651465
/*.ne1 =*/ ne1,
14661466
/*.r2 =*/ r2,
14671467
/*.r3 =*/ r3,
1468-
/*.nsg =*/ nsg,
1469-
/*.nxpsg =*/ nxpsg,
1470-
/*.r1ptg =*/ r1ptg,
14711468
};
14721469

14731470
ggml_metal_encoder_set_pipeline(enc, pipeline);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2883,7 +2883,8 @@ static inline void helper_mv_reduce_and_write(
28832883
}
28842884
}
28852885

2886-
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
2886+
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
2887+
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
28872888

28882889
template<typename block_q_type, short NR0, short NW, typename args_t>
28892890
void mul_vec_q_n_f32_impl(
@@ -3108,7 +3109,7 @@ kernel void kernel_mul_mv_q8_0_f32(
31083109

31093110
// mat-vec kernel processing in chunks of float4
31103111
// chpb - chunks per quantization block
3111-
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
3112+
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
31123113
void kernel_mul_mv_ext_q4_f32_impl(
31133114
constant ggml_metal_kargs_mul_mv_ext & args,
31143115
device const char * src0,
@@ -3117,6 +3118,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
31173118
uint3 tgpig[[threadgroup_position_in_grid]],
31183119
ushort tiisg[[thread_index_in_simdgroup]],
31193120
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3121+
const short NSG = FC_mul_mv_nsg;
3122+
const short nxpsg = FC_mul_mv_nxpsg;
3123+
31203124
const short chpt = 4; // chunks per thread
31213125

31223126
//const short nxpsg = (32);
@@ -3125,7 +3129,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
31253129
const short tx = tiisg%nxpsg;
31263130
const short ty = tiisg/nxpsg;
31273131

3128-
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
3132+
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
31293133
const int i11 = tgpig.y*r1ptg;
31303134
const int i1m = tgpig.z;
31313135

@@ -3208,7 +3212,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
32083212
}
32093213

32103214
// mat-vec kernel processing in chunks of float4x4
3211-
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
3215+
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
32123216
void kernel_mul_mv_ext_q4x4_f32_impl(
32133217
constant ggml_metal_kargs_mul_mv_ext & args,
32143218
device const char * src0,
@@ -3217,6 +3221,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
32173221
uint3 tgpig[[threadgroup_position_in_grid]],
32183222
ushort tiisg[[thread_index_in_simdgroup]],
32193223
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3224+
const short NSG = FC_mul_mv_nsg;
3225+
const short nxpsg = FC_mul_mv_nxpsg;
3226+
32203227
const short chpt = 1;
32213228

32223229
//const short nxpsg = (32);
@@ -3225,7 +3232,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
32253232
const short tx = tiisg%nxpsg;
32263233
const short ty = tiisg/nxpsg;
32273234

3228-
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
3235+
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
32293236
const int i11 = tgpig.y*r1ptg;
32303237
const int i1m = tgpig.z;
32313238

@@ -3322,12 +3329,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
33223329
uint3 tgpig[[threadgroup_position_in_grid]],
33233330
ushort tiisg[[thread_index_in_simdgroup]],
33243331
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3325-
switch (args.nxpsg) {
3326-
case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3327-
case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3328-
case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3329-
case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3330-
}
3332+
kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
33313333
}
33323334

33333335
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
@@ -3339,12 +3341,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
33393341
uint3 tgpig[[threadgroup_position_in_grid]],
33403342
ushort tiisg[[thread_index_in_simdgroup]],
33413343
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3342-
switch (args.nxpsg) {
3343-
case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3344-
case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3345-
case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3346-
case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
3347-
}
3344+
kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
33483345
}
33493346

33503347
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;

0 commit comments

Comments
 (0)