@@ -1579,7 +1579,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
15791579 }
15801580}
15811581
1582- static void ggml_metal_encode_node (
1582+ static bool ggml_metal_encode_node (
15831583 ggml_backend_t backend,
15841584 int idx,
15851585 id <MTLComputeCommandEncoder > encoder,
@@ -1599,7 +1599,7 @@ static void ggml_metal_encode_node(
15991599 struct ggml_tensor * dst = node;
16001600
16011601 if (ggml_is_empty (dst)) {
1602- return ;
1602+ return true ;
16031603 }
16041604
16051605 switch (dst->op ) {
@@ -1610,7 +1610,7 @@ static void ggml_metal_encode_node(
16101610 case GGML_OP_PERMUTE:
16111611 {
16121612 // noop -> next node
1613- } return ;
1613+ } return true ;
16141614 default :
16151615 {
16161616 } break ;
@@ -2214,6 +2214,8 @@ static void ggml_metal_encode_node(
22142214 {
22152215 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
22162216
2217+ GGML_ASSERT (ggml_is_contiguous (src0));
2218+
22172219 int nth = 32 ; // SIMD width
22182220
22192221 id <MTLComputePipelineState > pipeline = nil ;
@@ -2278,7 +2280,9 @@ static void ggml_metal_encode_node(
22782280
22792281 id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 32 );
22802282 if (!id_src0h) {
2281- break ;
2283+ // GGML_LOG_ERROR("%s: failed to allocate buffer for cpy, size = %zu, need = %zu, max available = %zu\n",
2284+ // __func__, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:32]);
2285+ return false ;
22822286 }
22832287
22842288 if (src0->type == GGML_TYPE_F16) {
@@ -4669,6 +4673,8 @@ static void ggml_metal_encode_node(
46694673 GGML_ABORT (" fatal error" );
46704674 }
46714675 }
4676+
4677+ return true ;
46724678}
46734679
46744680static enum ggml_status ggml_metal_graph_compute (
@@ -4683,13 +4689,16 @@ static enum ggml_status ggml_metal_graph_compute(
46834689 // number of threads in addition to the main thread
46844690 const int n_cb = ctx->n_cb ;
46854691
4692+ int n_try = 64 ;
4693+
46864694 // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
46874695 // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
46884696 // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
46894697 // each thread creates its own command buffer and enqueues the ops in parallel
46904698 //
46914699 // tests on M1 Pro and M2 Ultra using LLaMA models show that the optimal values for n_cb are 1 or 2
46924700
4701+ while (n_try-- > 0 ) {
46934702 @autoreleasepool {
46944703 ctx->gf = gf;
46954704
@@ -4752,8 +4761,6 @@ static enum ggml_status ggml_metal_graph_compute(
47524761 id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [n_cb].obj ;
47534762 [cmd_buf waitUntilCompleted ];
47544763
4755- ggml_metal_heap_reset (ctx->cmd_bufs [n_cb].heap );
4756-
47574764 MTLCommandBufferStatus status = [cmd_buf status ];
47584765 if (status != MTLCommandBufferStatusCompleted ) {
47594766 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, n_cb, status);
@@ -4769,8 +4776,6 @@ static enum ggml_status ggml_metal_graph_compute(
47694776 id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [i].obj ;
47704777 [cmd_buf waitUntilCompleted ];
47714778
4772- ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4773-
47744779 MTLCommandBufferStatus status = [cmd_buf status ];
47754780 if (status != MTLCommandBufferStatusCompleted ) {
47764781 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, i, status);
@@ -4805,6 +4810,54 @@ static enum ggml_status ggml_metal_graph_compute(
48054810 }
48064811 }
48074812
4813+ bool retry = false ;
4814+
4815+ // check heap statuses
4816+ for (int i = 0 ; i <= n_cb; ++i) {
4817+ struct ggml_metal_heap * heap = ctx->cmd_bufs [i].heap ;
4818+
4819+ const size_t need = 4 *heap->need ;
4820+
4821+ // printf("\nXXXXXXXXXXXXXXXXX cb %d, need = %zu, fail = %d, size = %zu\n", i, need, heap->fail, [heap->obj currentAllocatedSize]);
4822+
4823+ if (heap->fail == 0 ) {
4824+ ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4825+
4826+ continue ;
4827+ }
4828+
4829+ if (heap->fail == 2 ) {
4830+ GGML_LOG_ERROR (" %s : command buffer %d , MTLHeap ran out of buffers, max = %d \n " , __func__, i, heap->n );
4831+ return GGML_STATUS_ALLOC_FAILED;
4832+ }
4833+
4834+ if (heap->fail == 3 ) {
4835+ GGML_LOG_ERROR (" %s : command buffer %d , MTLHeap failed to allocate buffer, max = %d \n " , __func__, i, heap->n );
4836+ return GGML_STATUS_ALLOC_FAILED;
4837+ }
4838+
4839+ // GGML_LOG_INFO("%s: command buffer %d, MTLHeap need = %zu\n", __func__, i, need);
4840+
4841+ if (!ggml_metal_heap_resize (heap, need)) {
4842+ GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
4843+ return GGML_STATUS_ALLOC_FAILED;
4844+ }
4845+
4846+ retry = true ;
4847+ }
4848+
4849+ if (!retry) {
4850+ break ;
4851+ }
4852+
4853+ // printf("XXXXXXXXXXXXXXXXXXXXXXX retry\n");
4854+
4855+ if (n_try == 0 ) {
4856+ GGML_LOG_ERROR (" %s : failed to allocate heap memory\n " , __func__);
4857+ return GGML_STATUS_ALLOC_FAILED;
4858+ }
4859+ }
4860+
48084861 return GGML_STATUS_SUCCESS;
48094862}
48104863
@@ -5167,64 +5220,36 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51675220 id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
51685221 struct ggml_metal_heap * heap = ctx->cmd_bufs [cb_idx].heap ;
51695222
5170- int n_try = 2 ;
5171-
5172- while (n_try-- > 0 ) {
5173- id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
5174-
5175- int node_start = 0 ;
5176- int node_end = n_nodes_0;
5177-
5178- if (cb_idx < n_cb_l) {
5179- node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5180- node_end = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
5181- }
5182-
5183- const bool should_capture = ctx->capture_next_compute ;
5184-
5185- for (int idx = node_start; idx < node_end; ++idx) {
5186- if (should_capture) {
5187- [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5188- }
5189-
5190- ggml_metal_encode_node (backend, idx, encoder, heap);
5223+ id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
51915224
5192- if (should_capture) {
5193- [encoder popDebugGroup ];
5194- }
5195- }
5225+ int node_start = 0 ;
5226+ int node_end = n_nodes_0;
51965227
5197- [encoder endEncoding ];
5228+ if (cb_idx < n_cb_l) {
5229+ node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5230+ node_end = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
5231+ }
51985232
5199- if (heap->fail == 0 ) {
5200- break ;
5201- }
5233+ const bool should_capture = ctx->capture_next_compute ;
52025234
5203- if (heap-> fail == 2 ) {
5204- GGML_LOG_ERROR ( " %s : MTLHeap ran out of buffers, max = %d \n " , __func__, heap-> n );
5205- break ;
5235+ for ( int idx = node_start; idx < node_end; ++idx ) {
5236+ if (should_capture) {
5237+ [encoder pushDebugGroup: [ NSString stringWithCString: ggml_op_desc ( ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]] ;
52065238 }
52075239
5208- if (heap->fail == 3 ) {
5209- GGML_LOG_ERROR (" %s : MTLHeap failed to allocate buffer\n " , __func__);
5210- break ;
5211- }
5240+ const bool res = ggml_metal_encode_node (backend, idx, encoder, heap);
52125241
5213- if (n_try == 0 ) {
5214- GGML_LOG_ERROR (" %s : failed to allocate heap memory\n " , __func__);
5215- break ;
5242+ if (should_capture) {
5243+ [encoder popDebugGroup ];
52165244 }
52175245
5218- const size_t need = heap->need ;
5219-
5220- GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, need);
5221-
5222- if (!ggml_metal_heap_resize (heap, need)) {
5223- GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
5246+ if (!res) {
52245247 break ;
52255248 }
52265249 }
52275250
5251+ [encoder endEncoding ];
5252+
52285253 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
52295254 [cmd_buf commit ];
52305255 }
0 commit comments