@@ -4974,8 +4974,37 @@ static bool ggml_metal_encode_node(
49744974 default : GGML_ABORT (" not implemented" );
49754975 }
49764976
4977+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4978+
4979+ // TODO: support
4980+ // const int32_t nk00 = ne00/ggml_blck_size(dst->type);
4981+ const int32_t nk00 = ne00;
4982+
4983+ int nth = 32 ; // SIMD width
4984+
4985+ while (nth < nk00 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
4986+ nth *= 2 ;
4987+ }
4988+
4989+ // when rows are small, we can batch them together in a single threadgroup
4990+ int nrptg = 1 ;
4991+
4992+ // TODO: relax this constraint in the future
4993+ if (ggml_blck_size (src0->type ) == 1 && ggml_blck_size (dst->type ) == 1 ) {
4994+ if (nth > nk00) {
4995+ nrptg = (nth + nk00 - 1 )/nk00;
4996+ nth = nk00;
4997+
4998+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
4999+ nrptg--;
5000+ }
5001+ }
5002+ }
5003+
5004+ nth = MIN (nth, nk00);
5005+
49775006 ggml_metal_kargs_cpy args = {
4978- /* .ne00 =*/ ne00 ,
5007+ /* .ne00 =*/ nk00 ,
49795008 /* .ne01 =*/ ne01,
49805009 /* .ne02 =*/ ne02,
49815010 /* .ne03 =*/ ne03,
@@ -4998,11 +5027,7 @@ static bool ggml_metal_encode_node(
49985027 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
49995028 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
50005029
5001- GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
5002- int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
5003-
5004- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
5005-
5030+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
50065031 } break ;
50075032 case GGML_OP_SET:
50085033 {
0 commit comments