@@ -563,7 +563,9 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
563563    return  true ;
564564}
565565
566- static  id <MTLBuffer > ggml_metal_heap_alloc (struct  ggml_metal_heap * heap, size_t  size, size_t  alignment) {
566+ static  id <MTLBuffer > ggml_metal_heap_alloc (struct  ggml_metal_heap * heap, size_t  size) {
567+     const  size_t  alignment = 1024 *1024 ;
568+ 
567569    const  size_t  size_aligned = GGML_PAD (size, alignment);
568570
569571    heap->need  += size_aligned;
@@ -1583,7 +1585,8 @@ static bool ggml_metal_encode_node(
15831585                        ggml_backend_t    backend,
15841586                                   int    idx,
15851587          id <MTLComputeCommandEncoder >   encoder,
1586-                 struct  ggml_metal_heap * heap) {
1588+                 struct  ggml_metal_heap * heap,
1589+                                   bool    no_compute) {
15871590    struct  ggml_backend_metal_context        * ctx     = backend->context ;
15881591    struct  ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
15891592
@@ -1621,6 +1624,28 @@ static bool ggml_metal_encode_node(
16211624        GGML_ABORT (" unsupported op"  );
16221625    }
16231626
1627+     id <MTLBuffer > h_src0 = nil ;
1628+     switch  (dst->op ) {
1629+         case  GGML_OP_SOFT_MAX:
1630+             {
1631+                 h_src0 = ggml_metal_heap_alloc (heap, ggml_nbytes (src0));
1632+                 if  (!h_src0) {
1633+                     // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
1634+                     //         __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
1635+                     return  false ;
1636+                 } else  {
1637+                     // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
1638+                 }
1639+             } break ;
1640+         default :
1641+             {
1642+             } break ;
1643+     }
1644+ 
1645+     if  (no_compute) {
1646+         return  true ;
1647+     }
1648+ 
16241649    const  int64_t   ne00 = src0 ? src0->ne [0 ] : 0 ;
16251650    const  int64_t   ne01 = src0 ? src0->ne [1 ] : 0 ;
16261651    const  int64_t   ne02 = src0 ? src0->ne [2 ] : 0 ;
@@ -2278,23 +2303,14 @@ static bool ggml_metal_encode_node(
22782303                    /* .nb3  =*/   nb03,
22792304                };
22802305
2281-                 id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 64 *1024 );
2282-                 if  (!id_src0h) {
2283-                     // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
2284-                     //         __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
2285-                     return  true ;
2286-                 } else  {
2287-                     // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
2288-                 }
2289- 
22902306                if  (src0->type  == GGML_TYPE_F16) {
22912307                    [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
22922308                } else  {
22932309                    [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
22942310                }
22952311                [encoder setBytes: &args_cpy length: sizeof (args_cpy) atIndex: 0 ];
22962312                [encoder setBuffer: id_src0  offset: offs_src0        atIndex: 1 ];
2297-                 [encoder setBuffer: id_src0h  offset: 0                 atIndex: 2 ];
2313+                 [encoder setBuffer: h_src0    offset: 0                 atIndex: 2 ];
22982314
22992315                GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
23002316                int  nth_cpy = MIN (1024 , ne00 / ggml_blck_size (src0->type ));
@@ -2315,11 +2331,11 @@ static bool ggml_metal_encode_node(
23152331                };
23162332
23172333                [encoder setComputePipelineState: pipeline];
2318-                 [encoder setBuffer: id_src0h  offset: 0             atIndex: 0 ];
2334+                 [encoder setBuffer: h_src0  offset: 0                atIndex: 0 ];
23192335                if  (id_src1) {
23202336                    [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
23212337                } else  {
2322-                     [encoder setBuffer: id_src0h  offset: 0         atIndex: 1 ];
2338+                     [encoder setBuffer: h_src0  offset: 0            atIndex: 1 ];
23232339                }
23242340                [encoder setBuffer: id_dst offset: offs_dst       atIndex: 2 ];
23252341                [encoder setBytes: &args   length: sizeof (args)   atIndex: 3 ];
@@ -4732,6 +4748,12 @@ static enum ggml_status ggml_metal_graph_compute(
47324748            }
47334749        }
47344750
4751+         for  (int  i = 0 ; i <= n_cb; ++i) {
4752+             struct  ggml_metal_heap * heap = ctx->cmd_bufs [i].heap ;
4753+ 
4754+             [heap->obj setPurgeableState: MTLPurgeableStateNonVolatile ];
4755+         }
4756+ 
47354757        //  the main thread commits the first few commands immediately
47364758        //  cmd_buf[n_cb]
47374759        {
@@ -4824,6 +4846,7 @@ static enum ggml_status ggml_metal_graph_compute(
48244846
48254847        if  (heap->fail  == 0 ) {
48264848            ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4849+             [heap->obj setPurgeableState: MTLPurgeableStateEmpty ];
48274850
48284851            continue ;
48294852        }
@@ -5234,19 +5257,21 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
52345257
52355258        const  bool  should_capture = ctx->capture_next_compute ;
52365259
5260+         bool  no_compute = false ;
5261+ 
52375262        for  (int  idx = node_start; idx < node_end; ++idx) {
52385263            if  (should_capture) {
52395264                [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
52405265            }
52415266
5242-             const  bool  res = ggml_metal_encode_node (backend, idx, encoder, heap);
5267+             const  bool  res = ggml_metal_encode_node (backend, idx, encoder, heap, no_compute );
52435268
52445269            if  (should_capture) {
52455270                [encoder popDebugGroup ];
52465271            }
52475272
52485273            if  (!res) {
5249-                 break ;
5274+                 no_compute =  true ;
52505275            }
52515276        }
52525277
0 commit comments