@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
10151015    id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
10161016    id <MTLBuffer > id_dst  = dst  ? ggml_metal_get_buffer (dst,  &offs_dst)  : nil ;
10171017
1018-     // GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019-     // if (src0) {
1020-     //     GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021-     //             ggml_is_contiguous(src0), src0->name);
1022-     // }
1023-     // if (src1) {
1024-     //     GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025-     //             ggml_is_contiguous(src1), src1->name);
1026-     // }
1027-     // if (dst) {
1028-     //     GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
1029-     //             dst->name);
1030-     // }
1018+ #if  0 
1019+     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1020+     if (src0) {
1021+         GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1022+                 ggml_is_contiguous(src0), src0->name);
1023+     }
1024+     if (src1) {
1025+         GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1026+                 ggml_is_contiguous(src1), src1->name);
1027+     }
1028+     if (dst) {
1029+         GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1030+                 dst->name);
1031+     }
1032+ #endif 
10311033
10321034    id <MTLDevice > device = ctx_dev->mtl_device ;
10331035
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
18101812                            [encoder setBytes: &ne02    length: sizeof (ne02) atIndex: 4 ];
18111813                            [encoder setBytes: &nb01    length: sizeof (nb01) atIndex: 5 ];
18121814                            [encoder setBytes: &nb02    length: sizeof (nb02) atIndex: 6 ];
1813-                             [encoder setBytes: &ne12    length: sizeof (ne12) atIndex: 7 ];
1814-                             [encoder setBytes: &nb10    length: sizeof (nb10) atIndex: 8 ];
1815-                             [encoder setBytes: &nb11    length: sizeof (nb11) atIndex: 9 ];
1816-                             [encoder setBytes: &nb12    length: sizeof (nb12) atIndex: 10 ];
1817-                             [encoder setBytes: &ne0     length: sizeof (ne0)  atIndex: 11 ];
1818-                             [encoder setBytes: &ne1     length: sizeof (ne1)  atIndex: 12 ];
1819-                             [encoder setBytes: &r2      length: sizeof (r2)   atIndex: 13 ];
1820-                             [encoder setBytes: &r3      length: sizeof (r3)   atIndex: 14 ];
1815+                             [encoder setBytes: &nb03    length: sizeof (nb03) atIndex: 7 ];
1816+                             [encoder setBytes: &ne12    length: sizeof (ne12) atIndex: 8 ];
1817+                             [encoder setBytes: &nb10    length: sizeof (nb10) atIndex: 9 ];
1818+                             [encoder setBytes: &nb11    length: sizeof (nb11) atIndex: 10 ];
1819+                             [encoder setBytes: &nb12    length: sizeof (nb12) atIndex: 11 ];
1820+                             [encoder setBytes: &nb13    length: sizeof (nb13) atIndex: 12 ];
1821+                             [encoder setBytes: &ne0     length: sizeof (ne0)  atIndex: 13 ];
1822+                             [encoder setBytes: &ne1     length: sizeof (ne1)  atIndex: 14 ];
1823+                             [encoder setBytes: &r2      length: sizeof (r2)   atIndex: 15 ];
1824+                             [encoder setBytes: &r3      length: sizeof (r3)   atIndex: 16 ];
18211825                            [encoder setThreadgroupMemoryLength: 8192  atIndex: 0 ];
18221826                            [encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/64 , ne12*ne13) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
18231827                        } else  {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
19861990                            [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
19871991                            [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
19881992                            [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1989-                             [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 9 ];
1990-                             [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 10 ];
1991-                             [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 11 ];
1992-                             [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 12 ];
1993-                             [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 13 ];
1994-                             [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 14 ];
1995-                             [encoder setBytes: &ne0  length: sizeof (ne0)  atIndex: 15 ];
1996-                             [encoder setBytes: &ne1  length: sizeof (ne1)  atIndex: 16 ];
1997-                             [encoder setBytes: &r2   length: sizeof (r2)   atIndex: 17 ];
1998-                             [encoder setBytes: &r3   length: sizeof (r3)   atIndex: 18 ];
1993+                             [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1994+                             [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 10 ];
1995+                             [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 11 ];
1996+                             [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 12 ];
1997+                             [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 13 ];
1998+                             [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 14 ];
1999+                             [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 15 ];
2000+                             [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 16 ];
2001+                             [encoder setBytes: &ne0  length: sizeof (ne0)  atIndex: 17 ];
2002+                             [encoder setBytes: &ne1  length: sizeof (ne1)  atIndex: 18 ];
2003+                             [encoder setBytes: &r2   length: sizeof (r2)   atIndex: 19 ];
2004+                             [encoder setBytes: &r3   length: sizeof (r3)   atIndex: 20 ];
19992005
20002006                            if  (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
2001-                                      src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
2002-                                      src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2007+                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
2008+                                 src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
20032009                                [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
20042010                            }
20052011                            else  if  (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
20482054
20492055                GGML_ASSERT (src1t == GGML_TYPE_F32);
20502056
2057+                 GGML_ASSERT (ne03 == 1 );
2058+                 GGML_ASSERT (ne13 == 1 );
2059+ 
20512060                //  find the break-even point where the matrix-matrix kernel becomes more efficient compared
20522061                //  to the matrix-vector kernel
20532062                //  ne20 = n_used_experts
0 commit comments