@@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
26362636 GGML_ASSERT (ncpsg % 32 == 0 );
26372637
26382638 // simdgroups per threadgroup (a.k.a. warps)
2639- // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
26402639 const int64_t nsg = ne01 <= nqptg ? MAX (4 , MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )) : 4 ;
26412640
2642- const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof (float )/2 );
2641+ const size_t smem = nqptg*(ne00 + 2 * nsg*(ncpsg + nqptg))*(sizeof (float )/2 );
26432642
26442643 // printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
26452644 GGML_ASSERT (smem <= ctx->device .maxThreadgroupMemoryLength );
@@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
26562655 GGML_ASSERT (ncpsg % 32 == 0 );
26572656
26582657 // simdgroups per threadgroup (a.k.a. warps)
2659- // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
26602658 const int64_t nsgt = MAX (2 , MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 ));
26612659
26622660 int64_t nsg = 1 ;
@@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
26652663 }
26662664 nsg /= 2 ;
26672665
2668- // require power of 2
2669- // {
2670- // int64_t nsgm = 1;
2671- // while (nsgm < nsg) {
2672- // nsgm *= 2;
2673- // }
2674- // GGML_ASSERT(nsg == nsgm);
2675- // }
2676-
2677- const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof (float )/2 );
2666+ const size_t smem = (nqptg*(ne00 + 2 *nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof (float )/2 );
26782667
26792668 // printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
26802669 GGML_ASSERT (smem <= ctx->device .maxThreadgroupMemoryLength );
0 commit comments