Skip to content

Commit dba23c7

Browse files
author
alexju
committed
metal : refactor soft_max parameters into a struct
1 parent cd3dcdb commit dba23c7

File tree

3 files changed

+53
-54
lines changed

3 files changed

+53
-54
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,4 +330,15 @@ typedef struct {
330330
uint64_t nb3;
331331
} ggml_metal_kargs_sum_rows;
332332

333+
// Kernel arguments for the soft_max Metal kernels (kernel_soft_max, kernel_soft_max_4).
// The host fills one instance and passes it with a single setBytes call at buffer index 3;
// the kernels receive it as `constant ggml_metal_kargs_soft_max & args`.
//
//   ne00, ne01, ne02 - source dimensions: the kernels iterate over ne00 elements per row
//                      (ne00/4 float4 elements in kernel_soft_max_4) and use ne01/ne02 to
//                      split the threadgroup index into (i01, i02, i03)
//   scale            - multiplier applied to each source value before the mask term is added
//   max_bias         - the ALiBi slope is computed only when this is > 0.0f
//   m0, m1           - slope bases: m0 for head index < n_head_log2, m1 otherwise
//   n_head_log2      - head-index threshold that selects between m0 and m1
//
// NOTE(review): the same struct type is used by the host (.m) and the shader (.metal) in this
// change, so field order and types presumably must stay in sync on both sides - confirm before
// reordering or resizing any member.
typedef struct {
334+
int64_t ne00;
335+
int64_t ne01;
336+
int64_t ne02;
337+
float scale;
338+
float max_bias;
339+
float m0;
340+
float m1;
341+
uint32_t n_head_log2;
342+
} ggml_metal_kargs_soft_max;
343+
333344
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2024,8 +2024,17 @@ static void ggml_metal_encode_node(
20242024
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
20252025
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
20262026

2027-
// TODO: add ggml_metal_kargs struct
2028-
// TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
2027+
ggml_metal_kargs_soft_max args = {
2028+
/*.ne00 =*/ ne00,
2029+
/*.ne01 =*/ ne01,
2030+
/*.ne02 =*/ ne02,
2031+
/*.scale =*/ scale,
2032+
/*.max_bias =*/ max_bias,
2033+
/*.m0 =*/ m0,
2034+
/*.m1 =*/ m1,
2035+
/*.n_head_log2 =*/ n_head_log2,
2036+
};
2037+
20292038
[encoder setComputePipelineState:pipeline];
20302039
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
20312040
if (id_src1) {
@@ -2034,14 +2043,7 @@ static void ggml_metal_encode_node(
20342043
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
20352044
}
20362045
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2037-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
2038-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
2039-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
2040-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
2041-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
2042-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
2043-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
2044-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
2046+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
20452047

20462048
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
20472049

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -975,45 +975,38 @@ kernel void kernel_soft_max(
975975
device const char * src0,
976976
device const char * src1,
977977
device char * dst,
978-
constant int64_t & ne00,
979-
constant int64_t & ne01,
980-
constant int64_t & ne02,
981-
constant float & scale,
982-
constant float & max_bias,
983-
constant float & m0,
984-
constant float & m1,
985-
constant uint32_t & n_head_log2,
978+
constant ggml_metal_kargs_soft_max & args,
986979
threadgroup float * buf [[threadgroup(0)]],
987980
uint tgpig[[threadgroup_position_in_grid]],
988981
uint tpitg[[thread_position_in_threadgroup]],
989982
uint sgitg[[simdgroup_index_in_threadgroup]],
990983
uint tiisg[[thread_index_in_simdgroup]],
991984
uint ntg[[threads_per_threadgroup]]) {
992-
const int64_t i03 = (tgpig) / (ne02*ne01);
993-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
994-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
985+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
986+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
987+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
995988

996-
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
997-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
998-
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
989+
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
990+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
991+
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
999992

1000993
float slope = 1.0f;
1001994

1002995
// ALiBi
1003-
if (max_bias > 0.0f) {
996+
if (args.max_bias > 0.0f) {
1004997
const int64_t h = i02;
1005998

1006-
const float base = h < n_head_log2 ? m0 : m1;
1007-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
999+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1000+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
10081001

10091002
slope = pow(base, exp);
10101003
}
10111004

10121005
// parallel max
10131006
float lmax = -INFINITY;
10141007

1015-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1016-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
1008+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1009+
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
10171010
}
10181011

10191012
// find the max value in the block
@@ -1037,8 +1030,8 @@ kernel void kernel_soft_max(
10371030

10381031
// parallel sum
10391032
float lsum = 0.0f;
1040-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1041-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1033+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1034+
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
10421035
lsum += exp_psrc0;
10431036
pdst[i00] = exp_psrc0;
10441037
}
@@ -1068,7 +1061,7 @@ kernel void kernel_soft_max(
10681061

10691062
const float inv_sum = 1.0f/sum;
10701063

1071-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1064+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
10721065
pdst[i00] *= inv_sum;
10731066
}
10741067
}
@@ -1078,44 +1071,37 @@ kernel void kernel_soft_max_4(
10781071
device const char * src0,
10791072
device const char * src1,
10801073
device char * dst,
1081-
constant int64_t & ne00,
1082-
constant int64_t & ne01,
1083-
constant int64_t & ne02,
1084-
constant float & scale,
1085-
constant float & max_bias,
1086-
constant float & m0,
1087-
constant float & m1,
1088-
constant uint32_t & n_head_log2,
1074+
constant ggml_metal_kargs_soft_max & args,
10891075
threadgroup float * buf [[threadgroup(0)]],
10901076
uint tgpig[[threadgroup_position_in_grid]],
10911077
uint tpitg[[thread_position_in_threadgroup]],
10921078
uint sgitg[[simdgroup_index_in_threadgroup]],
10931079
uint tiisg[[thread_index_in_simdgroup]],
10941080
uint ntg[[threads_per_threadgroup]]) {
1095-
const int64_t i03 = (tgpig) / (ne02*ne01);
1096-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1097-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1081+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1082+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1083+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
10981084

1099-
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1100-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
1101-
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1085+
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1086+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1087+
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
11021088

11031089
float slope = 1.0f;
11041090

1105-
if (max_bias > 0.0f) {
1091+
if (args.max_bias > 0.0f) {
11061092
const int64_t h = i02;
11071093

1108-
const float base = h < n_head_log2 ? m0 : m1;
1109-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1094+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1095+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
11101096

11111097
slope = pow(base, exp);
11121098
}
11131099

11141100
// parallel max
11151101
float4 lmax4 = -INFINITY;
11161102

1117-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1118-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1103+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1104+
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
11191105
}
11201106

11211107
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -1140,8 +1126,8 @@ kernel void kernel_soft_max_4(
11401126

11411127
// parallel sum
11421128
float4 lsum4 = 0.0f;
1143-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1144-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1129+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1130+
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
11451131
lsum4 += exp_psrc4;
11461132
pdst4[i00] = exp_psrc4;
11471133
}
@@ -1173,7 +1159,7 @@ kernel void kernel_soft_max_4(
11731159

11741160
const float inv_sum = 1.0f/sum;
11751161

1176-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1162+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
11771163
pdst4[i00] *= inv_sum;
11781164
}
11791165
}

0 commit comments

Comments
 (0)