@@ -821,13 +821,23 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
 
     // the callback given to the thread pool
     void (^encode_async)(size_t ith);
+    void (^encode_next)(void);
 
     // n_cb command buffers + 1 used by the main thread
     struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs_next[2];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
+
+    // reuse info
+    int i_next;
+
+    int n_nodes_max;
+    int n_nodes_prev;
+
+    struct ggml_tensor * cg_nodes;
 };
 
 // MSL code
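These new context fields back a two-slot ping-pong scheme for pre-encoded command buffers: cmd_bufs_next[2] holds one buffer that the GPU is executing and one that the CPU is encoding, i_next selects the encode slot, and n_nodes_prev/cg_nodes keep a snapshot of the previous graph so a repeat can be detected. A minimal C sketch of the rotation, with a hypothetical slot_t standing in for ggml_metal_command_buffer:

// sketch: two-slot ping-pong rotation (slot_t is a stand-in type, not ggml API)
#include <stdio.h>

typedef struct { int id; } slot_t;

int main(void) {
    slot_t slots[2] = { {0}, {1} };
    int i_next = 0; // mirrors ctx->i_next

    for (int iter = 0; iter < 4; ++iter) {
        // submit the slot that was encoded during the previous iteration ...
        printf("iter %d: submit slot %d, encode into slot %d\n",
               iter, slots[(i_next + 1) % 2].id, slots[i_next].id);
        // ... and flip so the next iteration targets the other slot
        i_next = (i_next + 1) % 2;
    }
    return 0;
}

Each call submits the slot filled on the previous call and refills the other one, so the GPU and the encoder never touch the same buffer at the same time.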
@@ -1084,13 +1094,21 @@ @implementation GGMLMetalClass
 
     ctx->gf = nil;
     ctx->encode_async = nil;
+    ctx->encode_next = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
         ctx->cmd_bufs[i].obj = nil;
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
+    for (int i = 0; i < 2; ++i) {
+        ctx->cmd_bufs_next[i].obj = nil;
+
+        ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs_next[i].mem_pool->device = device;
+    }
+
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
 
+    ctx->i_next = 0;
+
+    ctx->n_nodes_max  = 16384;
+    ctx->n_nodes_prev = -1;
+
+    ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     return ctx;
 }
 
15261551
@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     }
 
     Block_release(ctx->encode_async);
+    Block_release(ctx->encode_next);
 
     [ctx->queue release];
 
@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
     }
 
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);
+
     dispatch_release(ctx->d_queue);
 
+    ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     free(ctx);
 }
 
@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    //const int64_t t_start = ggml_time_us();
+
+    ///////////////////////////////////////////////////
+    // hacky way to determine that the graph is the same as the previous one
+    //
+    bool can_reuse = true;
+
+    if (gf->n_nodes > ctx->n_nodes_max) {
+        can_reuse = false;
+    }
+
+    if (gf->n_nodes != ctx->n_nodes_prev) {
+        can_reuse = false;
+    }
+
+    if (can_reuse) {
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
+                can_reuse = false;
+                break;
+            }
+        }
+    }
+
+    if (!can_reuse) {
+        ctx->n_nodes_prev = gf->n_nodes;
+
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
+        }
+    }
+    ////////////////////////////////////////////////////
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
 
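The check above is a deliberately blunt fingerprint (the author calls it "hacky"): if the node count matches the previous call, every struct ggml_tensor in the graph is memcmp-ed byte-for-byte against the snapshot in cg_nodes, so any change in shape, op, parameters, or source pointers falls back to a full re-encode. One caveat: on a miss the refresh loop copies gf->n_nodes entries without re-checking n_nodes_max. A self-contained sketch of the same pattern, with a hypothetical node_t standing in for struct ggml_tensor and the copy clamped to stay in bounds:

// sketch: snapshot-and-compare graph reuse check (node_t is a stand-in type)
#include <stdbool.h>
#include <string.h>

#define N_NODES_MAX 16384 // mirrors ctx->n_nodes_max

typedef struct { int op; long ne[4]; void * src[2]; } node_t;

static node_t snapshot[N_NODES_MAX]; // mirrors ctx->cg_nodes
static int    n_prev = -1;           // mirrors ctx->n_nodes_prev

// returns true if the graph is byte-identical to the previous call
static bool graph_can_reuse(node_t ** nodes, int n_nodes) {
    bool can_reuse = (n_nodes <= N_NODES_MAX) && (n_nodes == n_prev);

    for (int i = 0; can_reuse && i < n_nodes; ++i) {
        can_reuse = (memcmp(nodes[i], &snapshot[i], sizeof(node_t)) == 0);
    }

    if (!can_reuse) { // refresh the snapshot for the next call
        n_prev = n_nodes;
        for (int i = 0; i < n_nodes && i < N_NODES_MAX; ++i) {
            snapshot[i] = *nodes[i];
        }
    }
    return can_reuse;
}

int main(void) {
    node_t a = { .op = 1, .ne = {2, 2, 1, 1}, .src = {0} };
    node_t * g[1] = { &a };

    graph_can_reuse(g, 1);                // first call: takes the snapshot, false
    return graph_can_reuse(g, 1) ? 0 : 1; // second call: identical graph, true
}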
@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
-    // the main thread commits the first few commands immediately
-    // cmd_buf[n_cb]
-    {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[n_cb].obj = cmd_buf;
-
-        [cmd_buf enqueue];
-        ctx->encode_async(n_cb);
-    }
-
-    // prepare the rest of the command buffers asynchronously
-    // cmd_buf[0..n_cb)
-    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
+    if (!can_reuse) {
+        // the main thread commits the first few commands immediately
+        // cmd_buf[n_cb]
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        // always enqueue the first two command buffers
-        // enqueue all of the command buffers if we don't need to abort
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
             [cmd_buf enqueue];
+            ctx->encode_async(n_cb);
         }
-    }
-
-    dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
-    // wait for completion and check status of each command buffer
-    // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
+        // prepare the rest of the command buffers asynchronously
+        // cmd_buf[0..n_cb)
+        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+            // always enqueue the first two command buffers
+            // enqueue all of the command buffers if we don't need to abort
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [cmd_buf enqueue];
             }
-
-            return GGML_STATUS_FAILED;
         }
-    }
 
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
+        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+        // encode the command buffer for the next iter while the GPU has already started
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
             }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
 
-            return GGML_STATUS_FAILED;
+            ctx->encode_next();
         }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
         }
 
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+            if (!next_buffer) {
+                continue;
+            }
+
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
         }
+    } else {
+        struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1) % 2];
+
+        // directly submit the command buffer that we have prepared in the previous iteration
+        [ctx->cmd_bufs_next[(ctx->i_next + 1) % 2].obj commit];
 
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
+        // encode the command buffer for the next iter
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
+            }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
+
+            ctx->encode_next();
         }
 
-        [next_buffer commit];
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+        }
     }
 
     if (!should_capture && ctx->capture_started) {
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    //printf("time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);
+
     return GGML_STATUS_SUCCESS;
 }
 
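The reuse path above is where the speedup comes from: on a hit, the graph is never split across the n_cb command buffers and no node is re-encoded on the hot path; the buffer encoded during the previous call is committed at once, and while the GPU executes it the CPU encodes the buffer for the following call. A schematic of the per-call control flow, with hypothetical commit_slot/encode_slot/wait_slot stand-ins for [cmd_buf commit], ctx->encode_next() and [cmd_buf waitUntilCompleted]:

// sketch: overlapping GPU execution with CPU encoding across calls
#include <stdio.h>

static void commit_slot(int s) { printf("  commit slot %d (GPU starts)\n", s); }
static void encode_slot(int s) { printf("  encode slot %d (CPU, overlapped with GPU)\n", s); }
static void wait_slot  (int s) { printf("  wait on slot %d (catch OOM/errors)\n", s); }

int main(void) {
    int i_next = 0;

    for (int call = 0; call < 3; ++call) { // successive graph_compute calls, all reuse hits
        printf("call %d:\n", call);
        const int cur = (i_next + 1) % 2;  // slot encoded during the previous call
        commit_slot(cur);                  // submit immediately, nothing to re-encode
        encode_slot(i_next);               // prepare the buffer for the next call
        i_next = (i_next + 1) % 2;         // the flip done inside encode_next
        wait_slot(cur);                    // check status before returning
    }
    return 0;
}

The commented-out t_start/printf pair bracketing the function is the measurement hook for exactly this saving.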
@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         Block_release(ctx->encode_async);
     }
 
+    if (ctx->encode_next) {
+        Block_release(ctx->encode_next);
+    }
+
     ctx->encode_async = Block_copy(^(size_t iter) {
         const int cb_idx = iter;
         const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             [cmd_buf commit];
         }
     });
+
+    ctx->encode_next = Block_copy(^(void) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+
+        int node_start = 0;
+        int node_end   = ctx->gf->n_nodes;
+
+        const bool should_capture = ctx->capture_next_compute;
+
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+            }
+
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
+
+            if (!res) {
+                break;
+            }
+        }
+
+        [encoder endEncoding];
+
+        ctx->i_next = (ctx->i_next + 1) % 2;
+    });
 }
 
 static struct ggml_backend_i ggml_backend_metal_i = {
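For contrast: encode_async (captured just above) splits the graph between the n_cb command buffers plus the n_main nodes taken by the main thread, while encode_next deliberately encodes the whole range [0, n_nodes) serially into one buffer, since the reuse path has no dispatch_apply fan-out. A rough illustration of that kind of range splitting; the exact formula used by encode_async is not shown in this excerpt, so the arithmetic below is illustrative only:

// sketch: dividing n_nodes across n_cb command buffers (illustrative math)
#include <stdio.h>

int main(void) {
    const int n_nodes = 1000, n_main = 128, n_cb = 4;

    // nodes [0, n_main) go to the main thread's extra buffer (cmd_bufs[n_cb]);
    // the rest are split roughly evenly across the n_cb async buffers
    const int n_rest   = n_nodes - n_main;
    const int n_per_cb = (n_rest + n_cb - 1) / n_cb;

    for (int cb = 0; cb < n_cb; ++cb) {
        const int start = n_main + cb*n_per_cb;
        const int end   = start + n_per_cb < n_nodes ? start + n_per_cb : n_nodes;
        printf("cmd_buf[%d]: nodes [%d, %d)\n", cb, start, end);
    }
    return 0;
}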