@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
10151015 id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
10161016 id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
10171017
1018- // GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019- // if (src0) {
1020- // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021- // ggml_is_contiguous(src0), src0->name);
1022- // }
1023- // if (src1) {
1024- // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025- // ggml_is_contiguous(src1), src1->name);
1026- // }
1027- // if (dst) {
1028- // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1029- // dst->name);
1030- // }
1018+ #if 0
1019+ GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1020+ if (src0) {
1021+ GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1022+ ggml_is_contiguous(src0), src0->name);
1023+ }
1024+ if (src1) {
1025+ GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1026+ ggml_is_contiguous(src1), src1->name);
1027+ }
1028+ if (dst) {
1029+ GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1030+ dst->name);
1031+ }
1032+ #endif
10311033
10321034 id <MTLDevice > device = ctx_dev->mtl_device ;
10331035
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
18101812 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
18111813 [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
18121814 [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
1813- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
1814- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 8 ];
1815- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 9 ];
1816- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 10 ];
1817- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 11 ];
1818- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 12 ];
1819- [encoder setBytes: &r2 length: sizeof (r2) atIndex: 13 ];
1820- [encoder setBytes: &r3 length: sizeof (r3) atIndex: 14 ];
1815+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 7 ];
1816+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 8 ];
1817+ [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 9 ];
1818+ [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 10 ];
1819+ [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 11 ];
1820+ [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 12 ];
1821+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 13 ];
1822+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 14 ];
1823+ [encoder setBytes: &r2 length: sizeof (r2) atIndex: 15 ];
1824+ [encoder setBytes: &r3 length: sizeof (r3) atIndex: 16 ];
18211825 [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
18221826 [encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/64 , ne12*ne13) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
18231827 } else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
19861990 [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
19871991 [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
19881992 [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1989- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 9 ];
1990- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 10 ];
1991- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 11 ];
1992- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 12 ];
1993- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 13 ];
1994- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 14 ];
1995- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 15 ];
1996- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 16 ];
1997- [encoder setBytes: &r2 length: sizeof (r2) atIndex: 17 ];
1998- [encoder setBytes: &r3 length: sizeof (r3) atIndex: 18 ];
1993+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1994+ [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 10 ];
1995+ [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 11 ];
1996+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 12 ];
1997+ [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 13 ];
1998+ [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 14 ];
1999+ [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 15 ];
2000+ [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 16 ];
2001+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 17 ];
2002+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 18 ];
2003+ [encoder setBytes: &r2 length: sizeof (r2) atIndex: 19 ];
2004+ [encoder setBytes: &r3 length: sizeof (r3) atIndex: 20 ];
19992005
20002006 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2001- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2002- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2007+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2008+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
20032009 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
20042010 }
20052011 else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
20482054
20492055 GGML_ASSERT (src1t == GGML_TYPE_F32);
20502056
2057+ GGML_ASSERT (ne03 == 1 );
2058+ GGML_ASSERT (ne13 == 1 );
2059+
20512060 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
20522061 // to the matrix-vector kernel
20532062 // ne20 = n_used_experts
0 commit comments