@@ -202,6 +202,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204 GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
205208 GGML_METAL_KERNEL_TYPE_RMS_NORM,
206209 GGML_METAL_KERNEL_TYPE_L2_NORM,
207210 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1169,9 @@ @implementation GGMLMetalClass
11661169 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
11671170 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true );
11681171 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
1172+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
1173+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
1174+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
11691175 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11701176 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11711177 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16301636
16311637 if (!use_bfloat) {
16321638 for (size_t i = 0 , n = 3 ; i < n; ++i) {
1633- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
1639+ if (op->src [i] != NULL && ( op->src [i]->type == GGML_TYPE_BF16 || op-> type == GGML_TYPE_BF16) ) {
16341640 return false ;
16351641 }
16361642 }
@@ -1798,6 +1804,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17981804 {
17991805 return op->ne [3 ] == 1 ;
18001806 }
1807+ case GGML_OP_SET_ROWS:
1808+ {
1809+ return op->src [0 ]->type == GGML_TYPE_F32 && ggml_blck_size (op->type ) == 1 ; // tmp
1810+ }
18011811 default :
18021812 return false ;
18031813 }
@@ -3757,13 +3767,68 @@ static bool ggml_metal_encode_node(
37573767 };
37583768
37593769 [encoder setComputePipelineState: pipeline];
3760- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3761- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3762- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3763- [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
3770+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3771+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3772+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3773+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
37643774
37653775 [encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
37663776 } break ;
3777+ case GGML_OP_SET_ROWS:
3778+ {
3779+ id <MTLComputePipelineState > pipeline = nil ;
3780+
3781+ switch (dst->type ) {
3782+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline ; break ;
3783+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline ; break ;
3784+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline ; break ;
3785+ default : GGML_ABORT (" not implemented" );
3786+ }
3787+
3788+ const int32_t nk0 = ne0/ggml_blck_size (dst->type );
3789+
3790+ int nth = 32 ; // SIMD width
3791+
3792+ while (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3793+ nth *= 2 ;
3794+ }
3795+
3796+ int nrptg = 1 ;
3797+ if (nth > nk0) {
3798+ nrptg = (nth + nk0 - 1 )/nk0;
3799+ nth = nk0;
3800+
3801+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3802+ nrptg--;
3803+ }
3804+ }
3805+
3806+ nth = MIN (nth, nk0);
3807+
3808+ ggml_metal_kargs_set_rows args = {
3809+ /* .nk0 =*/ nk0,
3810+ /* .ne01 =*/ ne01,
3811+ /* .nb01 =*/ nb01,
3812+ /* .nb02 =*/ nb02,
3813+ /* .nb03 =*/ nb03,
3814+ /* .ne11 =*/ ne11,
3815+ /* .ne12 =*/ ne12,
3816+ /* .nb10 =*/ nb10,
3817+ /* .nb11 =*/ nb11,
3818+ /* .nb12 =*/ nb12,
3819+ /* .nb1 =*/ nb1,
3820+ /* .nb2 =*/ nb2,
3821+ /* .nb3 =*/ nb3,
3822+ };
3823+
3824+ [encoder setComputePipelineState: pipeline];
3825+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3826+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3827+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3828+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3829+
3830+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
3831+ } break ;
37673832 case GGML_OP_RMS_NORM:
37683833 {
37693834 GGML_ASSERT (ne00 % 4 == 0 );
0 commit comments