@@ -1893,7 +1893,7 @@ void kernel_mul_mv_impl(
18931893
18941894 float sumf = 0 ;
18951895 for (int i = tiisg; i < args.ne00 /4 ; i += 32 ) {
1896- sumf += dot ((T14 ) x4[i], y4[i]);
1896+ sumf += dot ((float4 ) x4[i], (float4) y4[i]);
18971897 }
18981898
18991899 float all_sum = simd_sum (sumf);
@@ -3876,55 +3876,31 @@ kernel void kernel_cpy_f32_iq4_nl(
38763876}
38773877
38783878kernel void kernel_concat (
3879+ constant ggml_metal_kargs_concat & args,
38793880 device const char * src0,
38803881 device const char * src1,
38813882 device char * dst,
3882- constant int64_t & ne00,
3883- constant int64_t & ne01,
3884- constant int64_t & ne02,
3885- constant int64_t & ne03,
3886- constant uint64_t & nb00,
3887- constant uint64_t & nb01,
3888- constant uint64_t & nb02,
3889- constant uint64_t & nb03,
3890- constant int64_t & ne10,
3891- constant int64_t & ne11,
3892- constant int64_t & ne12,
3893- constant int64_t & ne13,
3894- constant uint64_t & nb10,
3895- constant uint64_t & nb11,
3896- constant uint64_t & nb12,
3897- constant uint64_t & nb13,
3898- constant int64_t & ne0,
3899- constant int64_t & ne1,
3900- constant int64_t & ne2,
3901- constant int64_t & ne3,
3902- constant uint64_t & nb0,
3903- constant uint64_t & nb1,
3904- constant uint64_t & nb2,
3905- constant uint64_t & nb3,
3906- constant int32_t & dim,
3907- uint3 tgpig[[threadgroup_position_in_grid]],
3908- uint3 tpitg[[thread_position_in_threadgroup]],
3909- uint3 ntg[[threads_per_threadgroup]]) {
3883+ uint3 tgpig[[threadgroup_position_in_grid]],
3884+ ushort3 tpitg[[thread_position_in_threadgroup]],
3885+ ushort3 ntg[[threads_per_threadgroup]]) {
39103886
3911- const int64_t i3 = tgpig.z ;
3912- const int64_t i2 = tgpig.y ;
3913- const int64_t i1 = tgpig.x ;
3887+ const int i3 = tgpig.z ;
3888+ const int i2 = tgpig.y ;
3889+ const int i1 = tgpig.x ;
39143890
3915- int64_t o[4 ] = {0 , 0 , 0 , 0 };
3916- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3891+ int o[4 ] = {0 , 0 , 0 , 0 };
3892+ o[args. dim ] = args. dim == 0 ? args. ne00 : (args. dim == 1 ? args. ne01 : (args. dim == 2 ? args. ne02 : args. ne03 ));
39173893
39183894 device const float * x;
39193895
3920- for (int i0 = tpitg.x ; i0 < ne0; i0 += ntg.x ) {
3921- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3922- x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3896+ for (int i0 = tpitg.x ; i0 < args. ne0 ; i0 += ntg.x ) {
3897+ if (i0 < args. ne00 && i1 < args. ne01 && i2 < args. ne02 && i3 < args. ne03 ) {
3898+ x = (device const float *)(src0 + (i3 )*args. nb03 + (i2 )*args. nb02 + (i1 )*args. nb01 + (i0 )*args. nb00 );
39233899 } else {
3924- x = (device const float *)(src1 + (i3 - o[3 ])*nb13 + (i2 - o[2 ])*nb12 + (i1 - o[1 ])*nb11 + (i0 - o[0 ])*nb10);
3900+ x = (device const float *)(src1 + (i3 - o[3 ])*args. nb13 + (i2 - o[2 ])*args. nb12 + (i1 - o[1 ])*args. nb11 + (i0 - o[0 ])*args. nb10 );
39253901 }
39263902
3927- device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3903+ device float * y = (device float *)(dst + i3*args. nb3 + i2*args. nb2 + i1*args. nb1 + i0*args. nb0 );
39283904
39293905 *y = *x;
39303906 }
0 commit comments