@@ -147,7 +147,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147147
148148enum ggml_metal_kernel_type {
149149 GGML_METAL_KERNEL_TYPE_ADD,
150+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
152+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
153+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150154 GGML_METAL_KERNEL_TYPE_ADD_ROW,
155+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
156+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
157+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
158+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151159 GGML_METAL_KERNEL_TYPE_SUB,
152160 GGML_METAL_KERNEL_TYPE_SUB_ROW,
153161 GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1137,15 @@ @implementation GGMLMetalClass
11291137 // simd_sum and simd_max require MTLGPUFamilyApple7
11301138
11311139 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1140+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1141+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1142+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1143+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
11321144 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
1145+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true );
1146+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true );
1147+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true );
1148+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true );
11331149 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
11341150 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
11351151 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
@@ -1875,7 +1891,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18751891 }
18761892}
18771893
1878- static bool ggml_metal_encode_node (
1894+ static int ggml_metal_encode_node (
18791895 ggml_backend_t backend,
18801896 int idx,
18811897 id <MTLComputeCommandEncoder > encoder,
@@ -1885,7 +1901,12 @@ static bool ggml_metal_encode_node(
18851901
18861902 struct ggml_cgraph * gf = ctx->gf ;
18871903
1888- struct ggml_tensor * node = ggml_graph_node (gf, idx);
1904+ enum ggml_op ops[8 ];
1905+
1906+ struct ggml_tensor ** nodes = ggml_graph_nodes (gf);
1907+ struct ggml_tensor * node = nodes[idx];
1908+
1909+ struct ggml_tensor ** fuse = nodes + idx + 1 ;
18891910
18901911 // GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
18911912
@@ -1895,7 +1916,7 @@ static bool ggml_metal_encode_node(
18951916 struct ggml_tensor * dst = node;
18961917
18971918 if (ggml_is_empty (dst)) {
1898- return true ;
1919+ return 1 ;
18991920 }
19001921
19011922 switch (dst->op ) {
@@ -1906,7 +1927,7 @@ static bool ggml_metal_encode_node(
19061927 case GGML_OP_PERMUTE:
19071928 {
19081929 // noop -> next node
1909- } return true ;
1930+ } return 1 ;
19101931 default :
19111932 {
19121933 } break ;
@@ -1973,6 +1994,8 @@ static bool ggml_metal_encode_node(
19731994 id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
19741995 id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
19751996
1997+ int n_fuse = 1 ;
1998+
19761999#if 0
19772000 GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
19782001 if (src0) {
@@ -2050,14 +2073,50 @@ static bool ggml_metal_encode_node(
20502073
20512074 id <MTLComputePipelineState > pipeline = nil ;
20522075
// Fusion search: decide how many consecutive ADD nodes starting at `idx` are
// encoded together. On exit n_fuse is the accepted chain length (8, 6, 4 or 2),
// or 1 when no candidate was accepted (the loop condition stops at 1).
2076+ {
// Op pattern handed to ggml_can_fuse: every slot is ADD.
2077+ ops[0 ] = GGML_OP_ADD;
2078+ ops[1 ] = GGML_OP_ADD;
2079+ ops[2 ] = GGML_OP_ADD;
2080+ ops[3 ] = GGML_OP_ADD;
2081+ ops[4 ] = GGML_OP_ADD;
2082+ ops[5 ] = GGML_OP_ADD;
2083+ ops[6 ] = GGML_OP_ADD;
2084+ ops[7 ] = GGML_OP_ADD;
2085+
// Longest candidate first. Odd lengths are skipped; only the *_FUSE_2/4/6/8
// kernel variants are declared in ggml_metal_kernel_type.
2086+ for (n_fuse = 8 ; n_fuse > 1 ; --n_fuse) {
2087+ if (n_fuse % 2 == 1 ) {
2088+ continue ;
2089+ }
// presumably ggml_can_fuse verifies that the n_fuse nodes starting at idx match
// `ops` and form a fusable chain - its definition is not visible here, confirm.
2090+ if (ggml_can_fuse (gf, idx, ops, n_fuse)) {
// Accept only if src[1] of the node right after idx (fuse[0]) and src[1] of the
// last node in the chain (fuse[n_fuse - 2]) have the same layout as this node's
// src[1]. For n_fuse == 2 both operands refer to the same node.
// NOTE(review): for n_fuse >= 4 the src[1] of the nodes in between is not
// compared here - confirm that this is guaranteed elsewhere.
2091+ if (ggml_are_same_layout (node->src [1 ], fuse[0 ]->src [1 ]) &&
2092+ ggml_are_same_layout (node->src [1 ], fuse[n_fuse - 2 ]->src [1 ])) {
2093+ break ;
2094+ }
2095+ }
2096+ }
2097+ }
2098+
20532099 if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
20542100 GGML_ASSERT (ggml_is_contiguous (src0));
20552101
20562102 // src1 is a row
20572103 GGML_ASSERT (ne11 == 1 );
20582104
20592105 switch (dst->op ) {
2060- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
2106+ case GGML_OP_ADD:
2107+ {
2108+ switch (n_fuse) {
2109+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline ; break ;
2110+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline ; break ;
2111+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline ; break ;
2112+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline ; break ;
2113+ default :
2114+ {
2115+ GGML_ASSERT (n_fuse == 1 );
2116+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ;
2117+ }
2118+ }
2119+ } break ;
20612120 case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
20622121 case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
20632122 case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
@@ -2067,7 +2126,21 @@ static bool ggml_metal_encode_node(
20672126 bcast_row = true ;
20682127 } else {
20692128 switch (dst->op ) {
2070- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2129+ case GGML_OP_ADD:
2130+ {
2131+ // GGML_LOG_INFO("XXXXXXXXXXXXXXXXXXXXXXXXX n_fuse = %d\n", n_fuse);
2132+ switch (n_fuse) {
2133+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2134+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2135+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2136+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2137+ default :
2138+ {
2139+ GGML_ASSERT (n_fuse == 1 );
2140+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2141+ }
2142+ }
2143+ } break ;
20712144 case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
20722145 case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
20732146 case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
@@ -2107,7 +2180,16 @@ static bool ggml_metal_encode_node(
21072180 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
21082181 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
21092182 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2110- [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
// Bind the extra operands of a fused chain: src[1] of each of the n_fuse - 1
// follower nodes goes to buffer indices 3 .. n_fuse + 1. With n_fuse == 1 the
// loop body never runs.
// NOTE(review): id_src1 / offs_src1 are overwritten on every iteration, so
// after the loop they no longer describe this node's own src1 - confirm that
// nothing later in this case relies on the original values.
2183+ for (int f = 0 ; f < n_fuse - 1 ; ++f) {
2184+ id_src1 = ggml_metal_get_buffer (fuse[f]->src [1 ], &offs_src1);
2185+
2186+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 3 + f];
2187+
// Last follower: redirect the destination to that node's buffer, so the
// binding below targets the final node of the chain instead of `dst`.
2188+ if (f + 1 == n_fuse - 1 ) {
2189+ id_dst = ggml_metal_get_buffer (fuse[f], &offs_dst);
2190+ }
2191+ }
// Destination goes after all src1 bindings; for n_fuse == 1 this is index 3.
2192+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 + n_fuse];
21112193
21122194 if (bcast_row) {
21132195 const int64_t n = ggml_nelements (dst)/4 ;
@@ -2674,7 +2756,7 @@ static bool ggml_metal_encode_node(
26742756 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
26752757 if (!h_src0) {
26762758 GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677- return false ;
2759+ return 0 ;
26782760 }
26792761
26802762 offs_src0 = 0;
@@ -3550,7 +3632,7 @@ static bool ggml_metal_encode_node(
35503632 id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
35513633 if (!h_src1) {
35523634 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3553- return false ;
3635+ return 0 ;
35543636 }
35553637
35563638 const int64_t neh0 = ne0;
@@ -3566,15 +3648,15 @@ static bool ggml_metal_encode_node(
35663648 id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
35673649 if (!h_dst) {
35683650 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3569- return false ;
3651+ return 0 ;
35703652 }
35713653
35723654 // tokens per expert
35733655 const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
35743656 id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
35753657 if (!h_tpe) {
35763658 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
3577- return false ;
3659+ return 0 ;
35783660 }
35793661
35803662 // id map
@@ -3583,7 +3665,7 @@ static bool ggml_metal_encode_node(
35833665 id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
35843666 if (!h_ids) {
35853667 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3586- return false ;
3668+ return 0 ;
35873669 }
35883670
35893671 {
@@ -5442,7 +5524,7 @@ static bool ggml_metal_encode_node(
54425524 }
54435525 }
54445526
5445- return true ;
5527+ return n_fuse ;
54465528}
54475529
54485530static enum ggml_status ggml_metal_graph_compute (
@@ -5948,20 +6030,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
59486030 struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs [cb_idx].mem_pool ;
59496031 ggml_metal_mem_pool_reset (mem_pool);
59506032
// Encode the nodes [node_start, node_end) into this encoder. The loop header no
// longer increments idx; the step is taken from the value returned by
// ggml_metal_encode_node at the bottom of the body.
5951- for (int idx = node_start; idx < node_end; ++idx ) {
6033+ for (int idx = node_start; idx < node_end;) {
59526034 if (should_capture) {
59536035 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
59546036 }
59556037
// res == 0 means encoding failed; any other value is added to idx below.
5956- const bool res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6038+ const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
59576039
59586040 if (should_capture) {
59596041 [encoder popDebugGroup ];
59606042 }
59616043
5962- if (! res) {
6044+ if (res == 0 ) {
59636045 break ;
59646046 }
6047+
// Skip every node consumed by this call (more than one when nodes were fused).
// NOTE(review): node_end is not passed to ggml_metal_encode_node, so a fused
// chain may reach past this range - confirm this cannot overlap with the nodes
// encoded for the next command buffer.
6048+ idx += res;
59656049 }
59666050
59676051 [encoder endEncoding ];
0 commit comments