@@ -975,45 +975,38 @@ kernel void kernel_soft_max(
975975 device const char * src0,
976976 device const char * src1,
977977 device char * dst,
978- constant int64_t & ne00,
979- constant int64_t & ne01,
980- constant int64_t & ne02,
981- constant float & scale,
982- constant float & max_bias,
983- constant float & m0,
984- constant float & m1,
985- constant uint32_t & n_head_log2,
978+ constant ggml_metal_kargs_soft_max & args,
986979 threadgroup float * buf [[threadgroup(0 )]],
987980 uint tgpig[[threadgroup_position_in_grid]],
988981 uint tpitg[[thread_position_in_threadgroup]],
989982 uint sgitg[[simdgroup_index_in_threadgroup]],
990983 uint tiisg[[thread_index_in_simdgroup]],
991984 uint ntg[[threads_per_threadgroup]]) {
992- const int64_t i03 = (tgpig) / (ne02*ne01);
993- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
994- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
985+ const int64_t i03 = (tgpig) / (args. ne02 *args. ne01 );
986+ const int64_t i02 = (tgpig - i03*args. ne02 *args. ne01 ) / args. ne01 ;
987+ const int64_t i01 = (tgpig - i03*args. ne02 *args. ne01 - i02*args. ne01 );
995988
996- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
997- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr ;
998- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
989+ device const float * psrc0 = (device const float *) src0 + (i03*args. ne02 *args. ne01 *args. ne00 + i02*args. ne01 *args. ne00 + i01*args. ne00 );
990+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args. ne00 : nullptr ;
991+ device float * pdst = (device float *) dst + (i03*args. ne02 *args. ne01 *args. ne00 + i02*args. ne01 *args. ne00 + i01*args. ne00 );
999992
1000993 float slope = 1 .0f ;
1001994
1002995 // ALiBi
1003- if (max_bias > 0 .0f ) {
996+ if (args. max_bias > 0 .0f ) {
1004997 const int64_t h = i02;
1005998
1006- const float base = h < n_head_log2 ? m0 : m1;
1007- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
999+ const float base = h < args. n_head_log2 ? args. m0 : args. m1 ;
1000+ const int exp = h < args. n_head_log2 ? h + 1 : 2 *(h - args. n_head_log2 ) + 1 ;
10081001
10091002 slope = pow (base, exp);
10101003 }
10111004
10121005 // parallel max
10131006 float lmax = -INFINITY;
10141007
1015- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1016- lmax = MAX (lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0 .0f ));
1008+ for (int i00 = tpitg; i00 < args. ne00 ; i00 += ntg) {
1009+ lmax = MAX (lmax, psrc0[i00]*args. scale + (pmask ? slope*pmask[i00] : 0 .0f ));
10171010 }
10181011
10191012 // find the max value in the block
@@ -1037,8 +1030,8 @@ kernel void kernel_soft_max(
10371030
10381031 // parallel sum
10391032 float lsum = 0 .0f ;
1040- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1041- const float exp_psrc0 = exp ((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0 .0f )) - max_val);
1033+ for (int i00 = tpitg; i00 < args. ne00 ; i00 += ntg) {
1034+ const float exp_psrc0 = exp ((psrc0[i00]*args. scale + (pmask ? slope*pmask[i00] : 0 .0f )) - max_val);
10421035 lsum += exp_psrc0;
10431036 pdst[i00] = exp_psrc0;
10441037 }
@@ -1068,7 +1061,7 @@ kernel void kernel_soft_max(
10681061
10691062 const float inv_sum = 1 .0f /sum;
10701063
1071- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1064+ for (int i00 = tpitg; i00 < args. ne00 ; i00 += ntg) {
10721065 pdst[i00] *= inv_sum;
10731066 }
10741067}
@@ -1078,44 +1071,37 @@ kernel void kernel_soft_max_4(
10781071 device const char * src0,
10791072 device const char * src1,
10801073 device char * dst,
1081- constant int64_t & ne00,
1082- constant int64_t & ne01,
1083- constant int64_t & ne02,
1084- constant float & scale,
1085- constant float & max_bias,
1086- constant float & m0,
1087- constant float & m1,
1088- constant uint32_t & n_head_log2,
1074+ constant ggml_metal_kargs_soft_max & args,
10891075 threadgroup float * buf [[threadgroup(0 )]],
10901076 uint tgpig[[threadgroup_position_in_grid]],
10911077 uint tpitg[[thread_position_in_threadgroup]],
10921078 uint sgitg[[simdgroup_index_in_threadgroup]],
10931079 uint tiisg[[thread_index_in_simdgroup]],
10941080 uint ntg[[threads_per_threadgroup]]) {
1095- const int64_t i03 = (tgpig) / (ne02*ne01);
1096- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1097- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1081+ const int64_t i03 = (tgpig) / (args. ne02 *args. ne01 );
1082+ const int64_t i02 = (tgpig - i03*args. ne02 *args. ne01 ) / args. ne01 ;
1083+ const int64_t i01 = (tgpig - i03*args. ne02 *args. ne01 - i02*args. ne01 );
10981084
1099- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4 ;
1100- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr ;
1101- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4 ;
1085+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*args. ne02 *args. ne01 *args. ne00 + i02*args. ne01 *args. ne00 + i01*args. ne00 )/4 ;
1086+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args. ne00 /4 : nullptr ;
1087+ device float4 * pdst4 = (device float4 *) dst + (i03*args. ne02 *args. ne01 *args. ne00 + i02*args. ne01 *args. ne00 + i01*args. ne00 )/4 ;
11021088
11031089 float slope = 1 .0f ;
11041090
1105- if (max_bias > 0 .0f ) {
1091+ if (args. max_bias > 0 .0f ) {
11061092 const int64_t h = i02;
11071093
1108- const float base = h < n_head_log2 ? m0 : m1;
1109- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
1094+ const float base = h < args. n_head_log2 ? args. m0 : args. m1 ;
1095+ const int exp = h < args. n_head_log2 ? h + 1 : 2 *(h - args. n_head_log2 ) + 1 ;
11101096
11111097 slope = pow (base, exp);
11121098 }
11131099
11141100 // parallel max
11151101 float4 lmax4 = -INFINITY;
11161102
1117- for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
1118- lmax4 = fmax (lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0 .0f )));
1103+ for (int i00 = tpitg; i00 < args. ne00 /4 ; i00 += ntg) {
1104+ lmax4 = fmax (lmax4, psrc4[i00]*args. scale + (float4)((pmask ? slope*pmask[i00] : 0 .0f )));
11191105 }
11201106
11211107 const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
@@ -1140,8 +1126,8 @@ kernel void kernel_soft_max_4(
11401126
11411127 // parallel sum
11421128 float4 lsum4 = 0 .0f ;
1143- for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
1144- const float4 exp_psrc4 = exp ((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0 .0f ))) - max_val);
1129+ for (int i00 = tpitg; i00 < args. ne00 /4 ; i00 += ntg) {
1130+ const float4 exp_psrc4 = exp ((psrc4[i00]*args. scale + (float4)((pmask ? slope*pmask[i00] : 0 .0f ))) - max_val);
11451131 lsum4 += exp_psrc4;
11461132 pdst4[i00] = exp_psrc4;
11471133 }
@@ -1173,7 +1159,7 @@ kernel void kernel_soft_max_4(
11731159
11741160 const float inv_sum = 1 .0f /sum;
11751161
1176- for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
1162+ for (int i00 = tpitg; i00 < args. ne00 /4 ; i00 += ntg) {
11771163 pdst4[i00] *= inv_sum;
11781164 }
11791165}
0 commit comments