@@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node(
14811481 memcpy (&max, ((const int32_t *) dst->op_params ) + 1 , sizeof (float ));
14821482
14831483 [encoder setComputePipelineState: pipeline];
1484- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1485- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1486- [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1487- [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1484+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1485+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1486+ [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1487+ [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
14881488
14891489 const int64_t n = ggml_nelements (dst);
14901490
@@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node(
16561656
16571657 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
16581658
1659+ // TODO: add ggml_metal_kargs struct
16591660 [encoder setComputePipelineState: pipeline];
16601661 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
16611662 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node(
17311732 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
17321733 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
17331734
1735+ // TODO: add ggml_metal_kargs struct
1736+ // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
17341737 [encoder setComputePipelineState: pipeline];
17351738 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17361739 if (id_src1) {
@@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node(
17471750 [encoder setBytes: &m0 length: sizeof (m0) atIndex: 8 ];
17481751 [encoder setBytes: &m1 length: sizeof (m1) atIndex: 9 ];
17491752 [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 10 ];
1753+
17501754 [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
17511755
17521756 [encoder dispatchThreadgroups: MTLSizeMake (ne01*ne02*ne03, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
@@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node(
17631767 pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline ;
17641768 }
17651769
1770+ // TODO: add ggml_metal_kargs struct
17661771 [encoder setComputePipelineState: pipeline];
17671772 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17681773 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node(
17871792
17881793 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline ;
17891794
1795+ // TODO: add ggml_metal_kargs struct
17901796 [encoder setComputePipelineState: pipeline];
17911797 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17921798 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node(
18571863
18581864 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline ;
18591865
1866+ // TODO: add ggml_metal_kargs struct
18601867 [encoder setComputePipelineState: pipeline];
18611868 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
18621869 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node(
25952602 default : GGML_ABORT (" not implemented" );
25962603 }
25972604
2605+ // TODO: add ggml_metal_kargs struct
25982606 [encoder setComputePipelineState: pipeline];
25992607 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
26002608 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node(
26642672
26652673 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline ;
26662674
2675+ // TODO: add ggml_metal_kargs struct
26672676 [encoder setComputePipelineState: pipeline];
26682677 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
26692678 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node(
28532862 default : GGML_ABORT (" fatal error" );
28542863 };
28552864
2865+ // TODO: add ggml_metal_kargs struct
28562866 [encoder setComputePipelineState: pipeline];
28572867 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
28582868 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node(
28932903
28942904 const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline ;
28952905
2906+ // TODO: add ggml_metal_kargs struct
28962907 [encoder setComputePipelineState: pipeline];
28972908 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
28982909 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node(
29272938
29282939 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline ;
29292940
2941+ // TODO: add ggml_metal_kargs struct
29302942 [encoder setComputePipelineState: pipeline];
29312943 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29322944 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node(
29632975
29642976 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
29652977
2978+ // TODO: add ggml_metal_kargs struct
29662979 [encoder setComputePipelineState: pipeline];
29672980 [encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
29682981 [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 1 ];
@@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node(
29842997
29852998 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
29862999
3000+ // TODO: add ggml_metal_kargs struct
29873001 [encoder setComputePipelineState: pipeline];
29883002 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29893003 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node(
30223036 default : GGML_ABORT (" fatal error" );
30233037 };
30243038
3039+ // TODO: add ggml_metal_kargs struct
30253040 [encoder setComputePipelineState: pipeline];
30263041 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
30273042 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node(
30403055
30413056 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
30423057
3058+ // TODO: add ggml_metal_kargs struct
30433059 [encoder setComputePipelineState: pipeline];
30443060 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
30453061 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node(
35173533 const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
35183534 const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
35193535
3536+ // TODO: add ggml_metal_kargs struct
35203537 [encoder setComputePipelineState: pipeline];
35213538 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
35223539 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
0 commit comments