@@ -563,7 +563,9 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
563563 return true ;
564564}
565565
566- static id <MTLBuffer > ggml_metal_heap_alloc (struct ggml_metal_heap * heap, size_t size, size_t alignment) {
566+ static id <MTLBuffer > ggml_metal_heap_alloc (struct ggml_metal_heap * heap, size_t size) {
567+ const size_t alignment = 1024 *1024 ;
568+
567569 const size_t size_aligned = GGML_PAD (size, alignment);
568570
569571 heap->need += size_aligned;
@@ -1582,7 +1584,8 @@ static bool ggml_metal_encode_node(
15821584 ggml_backend_t backend,
15831585 int idx,
15841586 id <MTLComputeCommandEncoder > encoder,
1585- struct ggml_metal_heap * heap) {
1587+ struct ggml_metal_heap * heap,
1588+ bool no_compute) {
15861589 struct ggml_backend_metal_context * ctx = backend->context ;
15871590 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
15881591
@@ -1620,6 +1623,28 @@ static bool ggml_metal_encode_node(
16201623 GGML_ABORT (" unsupported op" );
16211624 }
16221625
1626+ id <MTLBuffer > h_src0 = nil ;
1627+ switch (dst->op ) {
1628+ case GGML_OP_SOFT_MAX:
1629+ {
1630+ h_src0 = ggml_metal_heap_alloc (heap, ggml_nbytes (src0));
1631+ if (!h_src0) {
1632+ // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
1633+ // __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
1634+ return false ;
1635+ } else {
1636+ // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
1637+ }
1638+ } break ;
1639+ default :
1640+ {
1641+ } break ;
1642+ }
1643+
1644+ if (no_compute) {
1645+ return true ;
1646+ }
1647+
16231648 const int64_t ne00 = src0 ? src0->ne [0 ] : 0 ;
16241649 const int64_t ne01 = src0 ? src0->ne [1 ] : 0 ;
16251650 const int64_t ne02 = src0 ? src0->ne [2 ] : 0 ;
@@ -2277,23 +2302,14 @@ static bool ggml_metal_encode_node(
22772302 /* .nb3 =*/ nb03,
22782303 };
22792304
2280- id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 64 *1024 );
2281- if (!id_src0h) {
2282- // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
2283- // __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
2284- return true ;
2285- } else {
2286- // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
2287- }
2288-
22892305 if (src0->type == GGML_TYPE_F16) {
22902306 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
22912307 } else {
22922308 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
22932309 }
22942310 [encoder setBytes: &args_cpy length: sizeof (args_cpy) atIndex: 0 ];
22952311 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2296- [encoder setBuffer: id_src0h offset: 0 atIndex: 2 ];
2312+ [encoder setBuffer: h_src0 offset: 0 atIndex: 2 ];
22972313
22982314 GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
22992315 int nth_cpy = MIN (1024 , ne00 / ggml_blck_size (src0->type ));
@@ -2314,11 +2330,11 @@ static bool ggml_metal_encode_node(
23142330 };
23152331
23162332 [encoder setComputePipelineState: pipeline];
2317- [encoder setBuffer: id_src0h offset: 0 atIndex: 0 ];
2333+ [encoder setBuffer: h_src0 offset: 0 atIndex: 0 ];
23182334 if (id_src1) {
23192335 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
23202336 } else {
2321- [encoder setBuffer: id_src0h offset: 0 atIndex: 1 ];
2337+ [encoder setBuffer: h_src0 offset: 0 atIndex: 1 ];
23222338 }
23232339 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
23242340 [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
@@ -4731,6 +4747,12 @@ static enum ggml_status ggml_metal_graph_compute(
47314747 }
47324748 }
47334749
4750+ for (int i = 0 ; i <= n_cb; ++i) {
4751+ struct ggml_metal_heap * heap = ctx->cmd_bufs [i].heap ;
4752+
4753+ [heap->obj setPurgeableState: MTLPurgeableStateNonVolatile ];
4754+ }
4755+
47344756 // the main thread commits the first few commands immediately
47354757 // cmd_buf[n_cb]
47364758 {
@@ -4823,6 +4845,7 @@ static enum ggml_status ggml_metal_graph_compute(
48234845
48244846 if (heap->fail == 0 ) {
48254847 ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4848+ [heap->obj setPurgeableState: MTLPurgeableStateEmpty ];
48264849
48274850 continue ;
48284851 }
@@ -5233,19 +5256,21 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
52335256
52345257 const bool should_capture = ctx->capture_next_compute ;
52355258
5259+ bool no_compute = false ;
5260+
52365261 for (int idx = node_start; idx < node_end; ++idx) {
52375262 if (should_capture) {
52385263 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
52395264 }
52405265
5241- const bool res = ggml_metal_encode_node (backend, idx, encoder, heap);
5266+ const bool res = ggml_metal_encode_node (backend, idx, encoder, heap, no_compute );
52425267
52435268 if (should_capture) {
52445269 [encoder popDebugGroup ];
52455270 }
52465271
52475272 if (!res) {
5248- break ;
5273+ no_compute = true ;
52495274 }
52505275 }
52515276
0 commit comments