@@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node(
14811481                memcpy (&max, ((const  int32_t  *) dst->op_params ) + 1 , sizeof (float ));
14821482
14831483                [encoder setComputePipelineState: pipeline];
1484-                 [encoder setBuffer: id_src0    offset: offs_src0 atIndex: 0 ];
1485-                 [encoder setBuffer: id_dst     offset: offs_dst  atIndex: 1 ];
1486-                 [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1487-                 [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1484+                 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1485+                 [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
1486+                 [encoder setBytes: &min    length: sizeof (min) atIndex: 2 ];
1487+                 [encoder setBytes: &max    length: sizeof (max) atIndex: 3 ];
14881488
14891489                const  int64_t  n = ggml_nelements (dst);
14901490
@@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node(
16561656
16571657                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
16581658
1659+                 //  TODO: add ggml_metal_kargs struct
16591660                [encoder setComputePipelineState: pipeline];
16601661                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
16611662                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
@@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node(
17311732                const  float  m0 = powf (2 .0f , -(max_bias       ) / n_head_log2);
17321733                const  float  m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
17331734
1735+                 //  TODO: add ggml_metal_kargs struct
1736+                 //  TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
17341737                [encoder setComputePipelineState: pipeline];
17351738                [encoder setBuffer: id_src0 offset: offs_src0   atIndex: 0 ];
17361739                if  (id_src1) {
@@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node(
17471750                [encoder setBytes: &m0          length: sizeof (m0)          atIndex: 8 ];
17481751                [encoder setBytes: &m1          length: sizeof (m1)          atIndex: 9 ];
17491752                [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 10 ];
1753+ 
17501754                [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
17511755
17521756                [encoder dispatchThreadgroups: MTLSizeMake (ne01*ne02*ne03, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
@@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node(
17631767                    pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline ;
17641768                }
17651769
1770+                 //  TODO: add ggml_metal_kargs struct
17661771                [encoder setComputePipelineState: pipeline];
17671772                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17681773                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
@@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node(
17871792
17881793                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline ;
17891794
1795+                 //  TODO: add ggml_metal_kargs struct
17901796                [encoder setComputePipelineState: pipeline];
17911797                [encoder setBuffer: id_src0 offset: offs_src0    atIndex: 0 ];
17921798                [encoder setBuffer: id_src1 offset: offs_src1    atIndex: 1 ];
@@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node(
18571863
18581864                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline ;
18591865
1866+                 //  TODO: add ggml_metal_kargs struct
18601867                [encoder setComputePipelineState: pipeline];
18611868                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
18621869                [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node(
25952602                    default : GGML_ABORT (" not implemented"  );
25962603                }
25972604
2605+                 //  TODO: add ggml_metal_kargs struct
25982606                [encoder setComputePipelineState: pipeline];
25992607                [encoder setBuffer: id_src0     offset: offs_src0 atIndex: 0 ];
26002608                [encoder setBuffer: id_src1     offset: offs_src1 atIndex: 1 ];
@@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node(
26642672
26652673                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline ;
26662674
2675+                 //  TODO: add ggml_metal_kargs struct
26672676                [encoder setComputePipelineState: pipeline];
26682677                [encoder setBuffer: id_src0  offset: offs_src0        atIndex: 0 ];
26692678                [encoder setBuffer: id_dst   offset: offs_dst         atIndex: 1 ];
@@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node(
28532862                    default : GGML_ABORT (" fatal error"  );
28542863                };
28552864
2865+                 //  TODO: add ggml_metal_kargs struct
28562866                [encoder setComputePipelineState: pipeline];
28572867                [encoder setBuffer: id_src1 offset: offs_src1       atIndex: 0 ];
28582868                [encoder setBuffer: id_dst  offset: offs_dst        atIndex: 1 ];
@@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node(
28932903
28942904                const  id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline ;
28952905
2906+                 //  TODO: add ggml_metal_kargs struct
28962907                [encoder setComputePipelineState: pipeline];
28972908                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
28982909                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
@@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node(
29272938
29282939                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline ;
29292940
2941+                 //  TODO: add ggml_metal_kargs struct
29302942                [encoder setComputePipelineState: pipeline];
29312943                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29322944                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
@@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node(
29632975
29642976                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
29652977
2978+                 //  TODO: add ggml_metal_kargs struct
29662979                [encoder setComputePipelineState: pipeline];
29672980                [encoder setBuffer: id_dst  offset: offs_dst    atIndex: 0 ];
29682981                [encoder setBytes: &ne0   length: sizeof (ne0)   atIndex: 1 ];
@@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node(
29842997
29852998                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
29862999
3000+                 //  TODO: add ggml_metal_kargs struct
29873001                [encoder setComputePipelineState: pipeline];
29883002                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29893003                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 1 ];
@@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node(
30223036                    default : GGML_ABORT (" fatal error"  );
30233037                };
30243038
3039+                 //  TODO: add ggml_metal_kargs struct
30253040                [encoder setComputePipelineState: pipeline];
30263041                [encoder setBuffer: id_src0     offset: offs_src0        atIndex: 0 ];
30273042                [encoder setBuffer: id_dst      offset: offs_dst         atIndex: 1 ];
@@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node(
30403055
30413056                id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
30423057
3058+                 //  TODO: add ggml_metal_kargs struct
30433059                [encoder setComputePipelineState: pipeline];
30443060                [encoder setBuffer: id_src0 offset: offs_src0   atIndex: 0 ];
30453061                [encoder setBuffer: id_dst  offset: offs_dst    atIndex: 1 ];
@@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node(
35173533                const  int64_t  n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
35183534                const  int64_t  n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
35193535
3536+                 //  TODO: add ggml_metal_kargs struct
35203537                [encoder setComputePipelineState: pipeline];
35213538                [encoder setBuffer: id_src0 offset: offs_src0       atIndex: 0 ];
35223539                [encoder setBuffer: id_dst  offset: offs_dst        atIndex: 1 ];
0 commit comments