@@ -1485,10 +1485,10 @@ static void ggml_metal_encode_node(
14851485 memcpy (&max, ((const int32_t *) dst->op_params ) + 1 , sizeof (float ));
14861486
14871487 [encoder setComputePipelineState: pipeline];
1488- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1489- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1490- [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1491- [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1488+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1489+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1490+ [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1491+ [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
14921492
14931493 const int64_t n = ggml_nelements (dst);
14941494
@@ -1660,6 +1660,7 @@ static void ggml_metal_encode_node(
16601660
16611661 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
16621662
1663+ // TODO: add ggml_metal_kargs struct
16631664 [encoder setComputePipelineState: pipeline];
16641665 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
16651666 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1735,6 +1736,8 @@ static void ggml_metal_encode_node(
17351736 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
17361737 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
17371738
1739+ // TODO: add ggml_metal_kargs struct
1740+ // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
17381741 [encoder setComputePipelineState: pipeline];
17391742 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17401743 if (id_src1) {
@@ -1751,6 +1754,7 @@ static void ggml_metal_encode_node(
17511754 [encoder setBytes: &m0 length: sizeof (m0) atIndex: 8 ];
17521755 [encoder setBytes: &m1 length: sizeof (m1) atIndex: 9 ];
17531756 [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 10 ];
1757+
17541758 [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
17551759
17561760 [encoder dispatchThreadgroups: MTLSizeMake (ne01*ne02*ne03, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
@@ -1767,6 +1771,7 @@ static void ggml_metal_encode_node(
17671771 pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline ;
17681772 }
17691773
1774+ // TODO: add ggml_metal_kargs struct
17701775 [encoder setComputePipelineState: pipeline];
17711776 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17721777 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1791,6 +1796,7 @@ static void ggml_metal_encode_node(
17911796
17921797 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline ;
17931798
1799+ // TODO: add ggml_metal_kargs struct
17941800 [encoder setComputePipelineState: pipeline];
17951801 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
17961802 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -1861,6 +1867,7 @@ static void ggml_metal_encode_node(
18611867
18621868 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline ;
18631869
1870+ // TODO: add ggml_metal_kargs struct
18641871 [encoder setComputePipelineState: pipeline];
18651872 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
18661873 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2599,6 +2606,7 @@ static void ggml_metal_encode_node(
25992606 default : GGML_ABORT (" not implemented" );
26002607 }
26012608
2609+ // TODO: add ggml_metal_kargs struct
26022610 [encoder setComputePipelineState: pipeline];
26032611 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
26042612 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2668,6 +2676,7 @@ static void ggml_metal_encode_node(
26682676
26692677 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline ;
26702678
2679+ // TODO: add ggml_metal_kargs struct
26712680 [encoder setComputePipelineState: pipeline];
26722681 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
26732682 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2857,6 +2866,7 @@ static void ggml_metal_encode_node(
28572866 default : GGML_ABORT (" fatal error" );
28582867 };
28592868
2869+ // TODO: add ggml_metal_kargs struct
28602870 [encoder setComputePipelineState: pipeline];
28612871 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
28622872 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2897,6 +2907,7 @@ static void ggml_metal_encode_node(
28972907
28982908 const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline ;
28992909
2910+ // TODO: add ggml_metal_kargs struct
29002911 [encoder setComputePipelineState: pipeline];
29012912 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29022913 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2931,6 +2942,7 @@ static void ggml_metal_encode_node(
29312942
29322943 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline ;
29332944
2945+ // TODO: add ggml_metal_kargs struct
29342946 [encoder setComputePipelineState: pipeline];
29352947 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29362948 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2967,6 +2979,7 @@ static void ggml_metal_encode_node(
29672979
29682980 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
29692981
2982+ // TODO: add ggml_metal_kargs struct
29702983 [encoder setComputePipelineState: pipeline];
29712984 [encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
29722985 [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 1 ];
@@ -2988,6 +3001,7 @@ static void ggml_metal_encode_node(
29883001
29893002 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
29903003
3004+ // TODO: add ggml_metal_kargs struct
29913005 [encoder setComputePipelineState: pipeline];
29923006 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
29933007 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3026,6 +3040,7 @@ static void ggml_metal_encode_node(
30263040 default : GGML_ABORT (" fatal error" );
30273041 };
30283042
3043+ // TODO: add ggml_metal_kargs struct
30293044 [encoder setComputePipelineState: pipeline];
30303045 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
30313046 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3044,6 +3059,7 @@ static void ggml_metal_encode_node(
30443059
30453060 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
30463061
3062+ // TODO: add ggml_metal_kargs struct
30473063 [encoder setComputePipelineState: pipeline];
30483064 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
30493065 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3521,6 +3537,7 @@ static void ggml_metal_encode_node(
35213537 const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
35223538 const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
35233539
3540+ // TODO: add ggml_metal_kargs struct
35243541 [encoder setComputePipelineState: pipeline];
35253542 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
35263543 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
0 commit comments