@@ -1247,8 +1247,6 @@ static void ggml_metal_encode_node(
12471247
12481248 bool bcast_row = false ;
12491249
1250- int64_t nb = ne00; // used by the "row" kernels
1251-
12521250 id <MTLComputePipelineState > pipeline = nil ;
12531251
12541252 if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
@@ -1257,7 +1255,6 @@ static void ggml_metal_encode_node(
12571255 // src1 is a row
12581256 GGML_ASSERT (ne11 == 1 );
12591257
1260- nb = ne00 / 4 ;
12611258 switch (dst->op ) {
12621259 case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
12631260 case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
@@ -1277,36 +1274,39 @@ static void ggml_metal_encode_node(
12771274 }
12781275 }
12791276
1277+ ggml_metal_kargs_bin args = {
1278+ /* .ne00 =*/ ne00,
1279+ /* .ne01 =*/ ne01,
1280+ /* .ne02 =*/ ne02,
1281+ /* .ne03 =*/ ne03,
1282+ /* .nb00 =*/ nb00,
1283+ /* .nb01 =*/ nb01,
1284+ /* .nb02 =*/ nb02,
1285+ /* .nb03 =*/ nb03,
1286+ /* .ne10 =*/ ne10,
1287+ /* .ne11 =*/ ne11,
1288+ /* .ne12 =*/ ne12,
1289+ /* .ne13 =*/ ne13,
1290+ /* .nb10 =*/ nb10,
1291+ /* .nb11 =*/ nb11,
1292+ /* .nb12 =*/ nb12,
1293+ /* .nb13 =*/ nb13,
1294+ /* .ne0 =*/ ne0,
1295+ /* .ne1 =*/ ne1,
1296+ /* .ne2 =*/ ne2,
1297+ /* .ne3 =*/ ne3,
1298+ /* .nb0 =*/ nb0,
1299+ /* .nb1 =*/ nb1,
1300+ /* .nb2 =*/ nb2,
1301+ /* .nb3 =*/ nb3,
1302+ /* .offs =*/ offs,
1303+ };
1304+
12801305 [encoder setComputePipelineState: pipeline];
1281- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1282- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1283- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1284- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1285- [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1286- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
1287- [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 6 ];
1288- [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 7 ];
1289- [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 8 ];
1290- [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 9 ];
1291- [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 10 ];
1292- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 11 ];
1293- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 12 ];
1294- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 13 ];
1295- [encoder setBytes: &ne13 length: sizeof (ne13) atIndex: 14 ];
1296- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 15 ];
1297- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 16 ];
1298- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 17 ];
1299- [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 18 ];
1300- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 19 ];
1301- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 20 ];
1302- [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 21 ];
1303- [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 22 ];
1304- [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 23 ];
1305- [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 24 ];
1306- [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
1307- [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1308- [encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1309- [encoder setBytes: &nb length: sizeof (nb) atIndex: 28 ];
1306+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1307+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1308+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
1309+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
13101310
13111311 if (bcast_row) {
13121312 const int64_t n = ggml_nelements (dst)/4 ;
@@ -1404,35 +1404,39 @@ static void ggml_metal_encode_node(
14041404
14051405 const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
14061406
1407+ ggml_metal_kargs_bin args = {
1408+ /* .ne00 =*/ ne00,
1409+ /* .ne01 =*/ ne01,
1410+ /* .ne02 =*/ ne02,
1411+ /* .ne03 =*/ ne03,
1412+ /* .nb00 =*/ nb00,
1413+ /* .nb01 =*/ pnb1,
1414+ /* .nb02 =*/ pnb2,
1415+ /* .nb03 =*/ pnb3,
1416+ /* .ne10 =*/ ne10,
1417+ /* .ne11 =*/ ne11,
1418+ /* .ne12 =*/ ne12,
1419+ /* .ne13 =*/ ne13,
1420+ /* .nb10 =*/ nb10,
1421+ /* .nb11 =*/ nb11,
1422+ /* .nb12 =*/ nb12,
1423+ /* .nb13 =*/ nb13,
1424+ /* .ne0 =*/ ne0,
1425+ /* .ne1 =*/ ne1,
1426+ /* .ne2 =*/ ne2,
1427+ /* .ne3 =*/ ne3,
1428+ /* .nb0 =*/ nb0,
1429+ /* .nb1 =*/ pnb1,
1430+ /* .nb2 =*/ pnb2,
1431+ /* .nb3 =*/ pnb3,
1432+ /* .offs =*/ offs,
1433+ };
1434+
14071435 [encoder setComputePipelineState: pipeline];
1408- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1409- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1410- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1411- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1412- [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1413- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
1414- [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 6 ];
1415- [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 7 ];
1416- [encoder setBytes: &pnb1 length: sizeof (pnb1) atIndex: 8 ];
1417- [encoder setBytes: &pnb2 length: sizeof (pnb2) atIndex: 9 ];
1418- [encoder setBytes: &pnb3 length: sizeof (pnb3) atIndex: 10 ];
1419- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 11 ];
1420- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 12 ];
1421- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 13 ];
1422- [encoder setBytes: &ne13 length: sizeof (ne13) atIndex: 14 ];
1423- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 15 ];
1424- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 16 ];
1425- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 17 ];
1426- [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 18 ];
1427- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 19 ];
1428- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 20 ];
1429- [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 21 ];
1430- [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 22 ];
1431- [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 23 ];
1432- [encoder setBytes: &pnb1 length: sizeof (pnb1) atIndex: 24 ];
1433- [encoder setBytes: &pnb2 length: sizeof (pnb2) atIndex: 25 ];
1434- [encoder setBytes: &pnb3 length: sizeof (pnb3) atIndex: 26 ];
1435- [encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1436+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1437+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1438+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
1439+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
14361440
14371441 const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00);
14381442
0 commit comments