@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
24502450 nth *= 2 ;
24512451 }
24522452
2453+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
24532454 nth = MIN (nth, ne00);
24542455
24552456 ggml_metal_kargs_sum_rows args = {
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
37803781 nth *= 2 ;
37813782 }
37823783
3784+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
37833785 nth = MIN (nth, ne00/4 );
37843786
37853787 ggml_metal_kargs_rms_norm args = {
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
38163818 nth *= 2 ;
38173819 }
38183820
3821+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
38193822 nth = MIN (nth, ne00/4 );
38203823
38213824 ggml_metal_kargs_l2_norm args = {
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
38883891 nth *= 2 ;
38893892 }
38903893
3894+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
38913895 nth = MIN (nth, ne00/4 );
38923896
38933897 ggml_metal_kargs_norm args = {
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
49744978 default : GGML_ABORT (" not implemented" );
49754979 }
49764980
4981+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4982+
4983+ // TODO: support
4984+ // const int32_t nk00 = ne00/ggml_blck_size(dst->type);
4985+ const int32_t nk00 = ne00;
4986+
4987+ int nth = 32 ; // SIMD width
4988+
4989+ while (nth < nk00 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
4990+ nth *= 2 ;
4991+ }
4992+
4993+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
4994+
4995+ // when rows are small, we can batch them together in a single threadgroup
4996+ int nrptg = 1 ;
4997+
4998+ // TODO: relax this constraint in the future
4999+ if (ggml_blck_size (src0->type ) == 1 && ggml_blck_size (dst->type ) == 1 ) {
5000+ if (nth > nk00) {
5001+ nrptg = (nth + nk00 - 1 )/nk00;
5002+ nth = nk00;
5003+
5004+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
5005+ nrptg--;
5006+ }
5007+ }
5008+ }
5009+
5010+ nth = MIN (nth, nk00);
5011+
49775012 ggml_metal_kargs_cpy args = {
4978- /* .ne00 =*/ ne00 ,
5013+ /* .ne00 =*/ nk00 ,
49795014 /* .ne01 =*/ ne01,
49805015 /* .ne02 =*/ ne02,
49815016 /* .ne03 =*/ ne03,
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
49985033 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
49995034 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
50005035
5001- GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
5002- int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
5003-
5004- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
5005-
5036+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
50065037 } break ;
50075038 case GGML_OP_SET:
50085039 {
0 commit comments