@@ -640,7 +640,7 @@ @implementation ggml_metal_heap_ptr
640640@end
641641
642642//
643- // ggml_metal_mem_pool
643+ // ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
644644//
645645
646646struct ggml_metal_mem_pool {
@@ -4112,6 +4112,14 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41124112 default : break ;
41134113 }
41144114
4115+ // TODO: using mem pool allocations with concurrency enabled is not safe because the mem pool
4116+ // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer,
4117+ // so we add this extra barrier to prevent the race.
4118+ // the correct solution is to remove the mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
4119+ if (ctx_dev->use_concurrency ) {
4120+ ggml_metal_encode_mem_ranges_reset (ctx_enc);
4121+ }
4122+
41154123 // tokens per expert
41164124 const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
41174125 id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
@@ -4172,6 +4180,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41724180 [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02, 1 , 1 )];
41734181 }
41744182
4183+ // this barrier is always needed because the next kernel has to wait for the id maps to be computed
41754184 if (ctx_dev->use_concurrency ) {
41764185 ggml_metal_encode_mem_ranges_reset (ctx_enc);
41774186 }
@@ -5561,6 +5570,12 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55615570 GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
55625571 GGML_ASSERT (ne1*ne2*ne3 <= (1u << 31 ));
55635572
5573+ // using mem pool allocations with concurrency enabled is not safe [TAG_MEM_POOL_REMOVE]
5574+ // still, we assume that concurrent FA (flash attention) ops won't happen before the mem pools are removed, so the barrier below is left disabled
5575+ // if (ctx_dev->use_concurrency) {
5576+ // ggml_metal_encode_mem_ranges_reset(ctx_enc);
5577+ // }
5578+
55645579 const int32_t nrows = ne1*ne2*ne3;
55655580
55665581 // temp buffer for writing the results from each workgroup
@@ -5939,6 +5954,7 @@ static enum ggml_status ggml_metal_graph_compute(
59395954 // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
59405955 // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
59415956 // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
5957+ // [TAG_MEM_POOL_REMOVE]
59425958 // id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
59435959 id <MTLCommandBuffer > cmd_buf = [ctx->queue commandBuffer ];
59445960 [cmd_buf retain ];
0 commit comments