@@ -1381,25 +1381,29 @@ static void ggml_metal_encode_node(
13811381
13821382 const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ;
13831383
1384+ ggml_metal_kargs_cpy args = {
1385+ /* .ne00 =*/ ne00,
1386+ /* .ne01 =*/ ne01,
1387+ /* .ne02 =*/ ne02,
1388+ /* .ne03 =*/ ne03,
1389+ /* .nb00 =*/ nb00,
1390+ /* .nb01 =*/ nb01,
1391+ /* .nb02 =*/ nb02,
1392+ /* .nb03 =*/ nb03,
1393+ /* .ne0 =*/ ne0,
1394+ /* .ne1 =*/ ne1,
1395+ /* .ne2 =*/ ne2,
1396+ /* .ne3 =*/ ne3,
1397+ /* .nb0 =*/ nb0,
1398+ /* .nb1 =*/ nb1,
1399+ /* .nb2 =*/ nb2,
1400+ /* .nb3 =*/ nb3,
1401+ };
1402+
13841403 [encoder setComputePipelineState: pipeline];
1385- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1386- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1387- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1388- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1389- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
1390- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
1391- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
1392- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
1393- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
1394- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
1395- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
1396- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
1397- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
1398- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
1399- [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
1400- [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
1401- [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
1402- [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
1404+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1405+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1406+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
14031407
14041408 const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00);
14051409
@@ -3429,25 +3433,29 @@ static void ggml_metal_encode_node(
34293433 default : GGML_ABORT (" not implemented" );
34303434 }
34313435
3436+ ggml_metal_kargs_cpy args = {
3437+ /* .ne00 =*/ ne00,
3438+ /* .ne01 =*/ ne01,
3439+ /* .ne02 =*/ ne02,
3440+ /* .ne03 =*/ ne03,
3441+ /* .nb00 =*/ nb00,
3442+ /* .nb01 =*/ nb01,
3443+ /* .nb02 =*/ nb02,
3444+ /* .nb03 =*/ nb03,
3445+ /* .ne0 =*/ ne0,
3446+ /* .ne1 =*/ ne1,
3447+ /* .ne2 =*/ ne2,
3448+ /* .ne3 =*/ ne3,
3449+ /* .nb0 =*/ nb0,
3450+ /* .nb1 =*/ nb1,
3451+ /* .nb2 =*/ nb2,
3452+ /* .nb3 =*/ nb3,
3453+ };
3454+
34323455 [encoder setComputePipelineState: pipeline];
3433- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3434- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3435- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
3436- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
3437- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
3438- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
3439- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
3440- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
3441- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
3442- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
3443- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
3444- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
3445- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
3446- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
3447- [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
3448- [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
3449- [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
3450- [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
3456+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3457+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3458+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
34513459
34523460 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
34533461 } break ;
0 commit comments