@@ -147,7 +147,21 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147147
148148enum ggml_metal_kernel_type {
149149 GGML_METAL_KERNEL_TYPE_ADD,
150+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
152+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
153+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
154+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150157 GGML_METAL_KERNEL_TYPE_ADD_ROW,
158+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
159+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3,
160+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
161+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5,
162+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
163+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7,
164+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151165 GGML_METAL_KERNEL_TYPE_SUB,
152166 GGML_METAL_KERNEL_TYPE_SUB_ROW,
153167 GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1143,21 @@ @implementation GGMLMetalClass
11291143 // simd_sum and simd_max requires MTLGPUFamilyApple7
11301144
11311145 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1146+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1147+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true );
1148+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1149+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true );
1150+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1151+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1152+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
11321153 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
1154+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true );
1155+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3, add_row_fuse_3, true );
1156+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true );
1157+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5, add_row_fuse_5, true );
1158+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true );
1159+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7, add_row_fuse_7, true );
1160+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true );
11331161 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
11341162 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
11351163 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
@@ -1875,7 +1903,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18751903 }
18761904}
18771905
1878- static bool ggml_metal_encode_node (
1906+ static int ggml_metal_encode_node (
18791907 ggml_backend_t backend,
18801908 int idx,
18811909 id <MTLComputeCommandEncoder > encoder,
@@ -1885,7 +1913,10 @@ static bool ggml_metal_encode_node(
18851913
18861914 struct ggml_cgraph * gf = ctx->gf ;
18871915
1888- struct ggml_tensor * node = ggml_graph_node (gf, idx);
1916+ enum ggml_op ops[8 ];
1917+
1918+ struct ggml_tensor ** nodes = ggml_graph_nodes (gf) + idx;
1919+ struct ggml_tensor * node = nodes[0 ];
18891920
18901921 // GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
18911922
@@ -1895,7 +1926,7 @@ static bool ggml_metal_encode_node(
18951926 struct ggml_tensor * dst = node;
18961927
18971928 if (ggml_is_empty (dst)) {
1898- return true ;
1929+ return 1 ;
18991930 }
19001931
19011932 switch (dst->op ) {
@@ -1906,7 +1937,7 @@ static bool ggml_metal_encode_node(
19061937 case GGML_OP_PERMUTE:
19071938 {
19081939 // noop -> next node
1909- } return true ;
1940+ } return 1 ;
19101941 default :
19111942 {
19121943 } break ;
@@ -1973,6 +2004,8 @@ static bool ggml_metal_encode_node(
19732004 id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
19742005 id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
19752006
2007+ int n_fuse = 1 ;
2008+
19762009#if 0
19772010 GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
19782011 if (src0) {
@@ -2050,31 +2083,6 @@ static bool ggml_metal_encode_node(
20502083
20512084 id <MTLComputePipelineState > pipeline = nil ;
20522085
2053- if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2054- GGML_ASSERT (ggml_is_contiguous (src0));
2055-
2056- // src1 is a row
2057- GGML_ASSERT (ne11 == 1 );
2058-
2059- switch (dst->op ) {
2060- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
2061- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2062- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2063- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
2064- default : GGML_ABORT (" fatal error" );
2065- }
2066-
2067- bcast_row = true ;
2068- } else {
2069- switch (dst->op ) {
2070- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2071- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2072- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2073- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2074- default : GGML_ABORT (" fatal error" );
2075- }
2076- }
2077-
20782086 ggml_metal_kargs_bin args = {
20792087 /* .ne00 =*/ ne00,
20802088 /* .ne01 =*/ ne01,
@@ -2101,12 +2109,106 @@ static bool ggml_metal_encode_node(
21012109 /* .nb2 =*/ nb2,
21022110 /* .nb3 =*/ nb3,
21032111 /* .offs =*/ offs,
2112+ /* .o1 =*/ { offs_src1 },
21042113 };
21052114
2115+ {
2116+ ops[0 ] = GGML_OP_ADD;
2117+ ops[1 ] = GGML_OP_ADD;
2118+ ops[2 ] = GGML_OP_ADD;
2119+ ops[3 ] = GGML_OP_ADD;
2120+ ops[4 ] = GGML_OP_ADD;
2121+ ops[5 ] = GGML_OP_ADD;
2122+ ops[6 ] = GGML_OP_ADD;
2123+ ops[7 ] = GGML_OP_ADD;
2124+
2125+ size_t offs_fuse;
2126+ id <MTLBuffer > id_fuse;
2127+
2128+ for (n_fuse = 0 ; n_fuse <= 6 ; ++n_fuse) {
2129+ if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
2130+ break ;
2131+ }
2132+
2133+ if (!ggml_are_same_layout (nodes[n_fuse]->src [1 ], nodes[n_fuse + 1 ]->src [1 ])) {
2134+ break ;
2135+ }
2136+
2137+ // only fuse nodes if src1 is in the same Metal buffer
2138+ id_fuse = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse);
2139+ if (id_fuse != id_src1) {
2140+ break ;
2141+ }
2142+
2143+ args.o1 [n_fuse + 1 ] = offs_fuse;
2144+ }
2145+
2146+ ++n_fuse;
2147+ }
2148+
2149+ if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2150+ GGML_ASSERT (ggml_is_contiguous (src0));
2151+
2152+ // src1 is a row
2153+ GGML_ASSERT (ne11 == 1 );
2154+
2155+ switch (dst->op ) {
2156+ case GGML_OP_ADD:
2157+ {
2158+ switch (n_fuse) {
2159+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline ; break ;
2160+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline ; break ;
2161+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3].pipeline ; break ;
2162+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline ; break ;
2163+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5].pipeline ; break ;
2164+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline ; break ;
2165+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7].pipeline ; break ;
2166+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline ; break ;
2167+ default : GGML_ABORT (" fatal error" );
2168+ }
2169+ } break ;
2170+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2171+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2172+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
2173+ default : GGML_ABORT (" fatal error" );
2174+ }
2175+
2176+ bcast_row = true ;
2177+ } else {
2178+ switch (dst->op ) {
2179+ case GGML_OP_ADD:
2180+ {
2181+ switch (n_fuse) {
2182+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD ].pipeline ; break ;
2183+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2184+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline ; break ;
2185+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2186+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline ; break ;
2187+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2188+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline ; break ;
2189+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2190+ default : GGML_ABORT (" fatal error" );
2191+ }
2192+ } break ;
2193+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2194+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2195+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2196+ default : GGML_ABORT (" fatal error" );
2197+ }
2198+ }
2199+
2200+ if (n_fuse > 1 ) {
2201+ id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
2202+ }
2203+
21062204 [encoder setComputePipelineState: pipeline];
21072205 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
21082206 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2109- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2207+ if (dst->op == GGML_OP_ADD) {
2208+ [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2209+ } else {
2210+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2211+ }
21102212 [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
21112213
21122214 if (bcast_row) {
@@ -2239,6 +2341,7 @@ static bool ggml_metal_encode_node(
22392341 /* .nb2 =*/ pnb2,
22402342 /* .nb3 =*/ pnb3,
22412343 /* .offs =*/ offs,
2344+ /* .o1 =*/ { offs_src1 },
22422345 };
22432346
22442347 [encoder setComputePipelineState: pipeline];
@@ -2674,7 +2777,7 @@ static bool ggml_metal_encode_node(
26742777 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
26752778 if (!h_src0) {
26762779 GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677- return false ;
2780+ return 0 ;
26782781 }
26792782
26802783 offs_src0 = 0;
@@ -3550,7 +3653,7 @@ static bool ggml_metal_encode_node(
35503653 id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
35513654 if (!h_src1) {
35523655 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3553- return false ;
3656+ return 0 ;
35543657 }
35553658
35563659 const int64_t neh0 = ne0;
@@ -3566,15 +3669,15 @@ static bool ggml_metal_encode_node(
35663669 id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
35673670 if (!h_dst) {
35683671 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3569- return false ;
3672+ return 0 ;
35703673 }
35713674
35723675 // tokens per expert
35733676 const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
35743677 id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
35753678 if (!h_tpe) {
35763679 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
3577- return false ;
3680+ return 0 ;
35783681 }
35793682
35803683 // id map
@@ -3583,7 +3686,7 @@ static bool ggml_metal_encode_node(
35833686 id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
35843687 if (!h_ids) {
35853688 GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3586- return false ;
3689+ return 0 ;
35873690 }
35883691
35893692 {
@@ -5442,7 +5545,7 @@ static bool ggml_metal_encode_node(
54425545 }
54435546 }
54445547
5445- return true ;
5548+ return n_fuse ;
54465549}
54475550
54485551static enum ggml_status ggml_metal_graph_compute (
@@ -5948,20 +6051,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
59486051 struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs [cb_idx].mem_pool ;
59496052 ggml_metal_mem_pool_reset (mem_pool);
59506053
5951- for (int idx = node_start; idx < node_end; ++idx ) {
6054+ for (int idx = node_start; idx < node_end;) {
59526055 if (should_capture) {
59536056 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
59546057 }
59556058
5956- const bool res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6059+ const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
59576060
59586061 if (should_capture) {
59596062 [encoder popDebugGroup ];
59606063 }
59616064
5962- if (! res) {
6065+ if (res == 0 ) {
59636066 break ;
59646067 }
6068+
6069+ idx += res;
59656070 }
59666071
59676072 [encoder endEncoding ];
0 commit comments