@@ -232,28 +232,6 @@ - (void) dealloc {
232232@end
233233
234234enum ggml_metal_kernel_type {
235- GGML_METAL_KERNEL_TYPE_ADD,
236- GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
237- GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
238- GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
239- GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
240- GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
241- GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
242- GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
243- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
244- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
245- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
246- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
247- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
248- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
249- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
250- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
251- GGML_METAL_KERNEL_TYPE_SUB,
252- GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
253- GGML_METAL_KERNEL_TYPE_MUL,
254- GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
255- GGML_METAL_KERNEL_TYPE_DIV,
256- GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
257235 GGML_METAL_KERNEL_TYPE_ADD_ID,
258236 GGML_METAL_KERNEL_TYPE_REPEAT_F32,
259237 GGML_METAL_KERNEL_TYPE_REPEAT_F16,
@@ -319,9 +297,6 @@ - (void) dealloc {
319297 GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
320298 GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
321299 GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
322- GGML_METAL_KERNEL_TYPE_RMS_NORM,
323- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
324- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
325300 GGML_METAL_KERNEL_TYPE_L2_NORM,
326301 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
327302 GGML_METAL_KERNEL_TYPE_NORM,
@@ -1177,28 +1152,6 @@ @implementation GGMLMetalClass
11771152
11781153 // simd_sum and simd_max requires MTLGPUFamilyApple7
11791154
1180- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1181- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1182- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true );
1183- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1184- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true );
1185- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1186- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1187- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1188- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true );
1189- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true );
1190- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true );
1191- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true );
1192- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true );
1193- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true );
1194- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true );
1195- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true );
1196- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1197- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true );
1198- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
1199- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true );
1200- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
1201- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true );
12021155 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true );
12031156 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
12041157 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
@@ -1264,9 +1217,6 @@ @implementation GGMLMetalClass
12641217 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
12651218 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
12661219 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1267- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1268- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1269- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
12701220 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12711221 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12721222 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -1722,6 +1672,73 @@ @implementation GGMLMetalClass
17221672 GGML_UNUSED (op);
17231673}
17241674
// Look up the compute pipeline for an element-wise binary op, compiling it on
// first use.
//
//   backend - backend whose context holds the kernel cache
//   op      - GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL or GGML_OP_DIV;
//             any other value aborts
//   n_fuse  - number embedded in the "_fuse_<n>" suffix of the kernel name
//   row     - when true, selects the "_row_c4" kernel variant
//
// Returns the pipeline found by ggml_metal_get_kernel if there is one,
// otherwise whatever ggml_metal_compile_kernel returns.
static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin(
        ggml_backend_t backend, enum ggml_op op,
        int32_t n_fuse,
        bool row) {
    struct ggml_backend_metal_context * ctx = backend->context;

    char base[256]; // kernel function name passed to the compiler
    char name[256]; // key used for the kernel lookup

    @autoreleasepool {
        const char * op_str = "undefined";
        switch (op) {
            case GGML_OP_ADD: op_str = "add"; break;
            case GGML_OP_SUB: op_str = "sub"; break;
            case GGML_OP_MUL: op_str = "mul"; break;
            case GGML_OP_DIV: op_str = "div"; break;
            default: GGML_ABORT("fatal error");
        }

        // the name must not contain any whitespace, otherwise it can never
        // match a kernel function name
        if (row) {
            snprintf(base, sizeof(base), "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
        } else {
            snprintf(base, sizeof(base), "kernel_%s_fuse_%d", op_str, n_fuse);
        }

        // no extra specialization for these kernels: the lookup key is the base name
        snprintf(name, sizeof(name), "%s", base);

        id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
        if (res) {
            // kernel found
            return res;
        }

        return ggml_metal_compile_kernel(backend, base, name, nil);
    }
}
1711+
// Look up the compute pipeline for RMS norm, compiling it on first use.
//
//   backend - backend whose context holds the kernel cache
//   op      - the node being encoded; currently unused
//   n_fuse  - 1 selects kernel_rms_norm, 2 kernel_rms_norm_mul,
//             3 kernel_rms_norm_mul_add; any other value aborts
//
// Returns the pipeline found by ggml_metal_get_kernel if there is one,
// otherwise whatever ggml_metal_compile_kernel returns.
static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm(
        ggml_backend_t backend, struct ggml_tensor * op,
        int32_t n_fuse) {
    // marked up front: every path through the block below returns
    GGML_UNUSED(op);

    struct ggml_backend_metal_context * ctx = backend->context;

    char base[256]; // kernel function name passed to the compiler
    char name[256]; // key used for the kernel lookup

    @autoreleasepool {
        switch (n_fuse) {
            case 1: snprintf(base, sizeof(base), "kernel_rms_norm");         break;
            case 2: snprintf(base, sizeof(base), "kernel_rms_norm_mul");     break;
            case 3: snprintf(base, sizeof(base), "kernel_rms_norm_mul_add"); break;
            default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
        }

        // no extra specialization for these kernels: the lookup key is the base name
        snprintf(name, sizeof(name), "%s", base);

        id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
        if (res) {
            // kernel found
            return res;
        }

        return ggml_metal_compile_kernel(backend, base, name, nil);
    }
}
1741+
17251742static void ggml_metal_free (struct ggml_backend_metal_context * ctx) {
17261743 GGML_LOG_INFO (" %s : deallocating\n " , __func__);
17271744
@@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
23592376
23602377 bool bcast_row = false ;
23612378
2362- id <MTLComputePipelineState > pipeline = nil ;
2363-
23642379 ggml_metal_kargs_bin args = {
23652380 /* .ne00 =*/ ne00,
23662381 /* .ne01 =*/ ne01,
@@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
24412456 }
24422457 }
24432458
2459+ id <MTLComputePipelineState > pipeline = nil ;
2460+
24442461 if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
24452462 GGML_ASSERT (ggml_is_contiguous (src0));
24462463
24472464 // src1 is a row
24482465 GGML_ASSERT (ne11 == 1 );
24492466
2450- switch (dst->op ) {
2451- case GGML_OP_ADD:
2452- {
2453- switch (n_fuse) {
2454- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline ; break ;
2455- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline ; break ;
2456- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline ; break ;
2457- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline ; break ;
2458- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline ; break ;
2459- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline ; break ;
2460- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline ; break ;
2461- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline ; break ;
2462- default : GGML_ABORT (" fatal error" );
2463- }
2464- } break ;
2465- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline ; break ;
2466- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline ; break ;
2467- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline ; break ;
2468- default : GGML_ABORT (" fatal error" );
2469- }
2467+ pipeline = ggml_metal_get_pipeline_bin (backend, dst->op , n_fuse, true );
24702468
24712469 bcast_row = true ;
24722470 } else {
2473- switch (dst->op ) {
2474- case GGML_OP_ADD:
2475- {
2476- switch (n_fuse) {
2477- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD ].pipeline ; break ;
2478- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2479- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline ; break ;
2480- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2481- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline ; break ;
2482- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2483- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline ; break ;
2484- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2485- default : GGML_ABORT (" fatal error" );
2486- }
2487- } break ;
2488- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2489- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2490- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2491- default : GGML_ABORT (" fatal error" );
2492- }
2471+ pipeline = ggml_metal_get_pipeline_bin (backend, dst->op , n_fuse, false );
24932472 }
24942473
24952474 if (n_fuse > 1 ) {
@@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26502629 ggml_metal_encode_concurrency_reset (ctx_enc);
26512630 }
26522631
2653- const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
2654-
26552632 ggml_metal_kargs_bin args = {
26562633 /* .ne00 =*/ ne00,
26572634 /* .ne01 =*/ ne01,
@@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26812658 /* .o1 =*/ { offs_src1},
26822659 };
26832660
2661+ // const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
2662+ const id <MTLComputePipelineState > pipeline = ggml_metal_get_pipeline_bin (backend, GGML_OP_ADD, 1 , false );
2663+
26842664 [encoder setComputePipelineState: pipeline];
26852665 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
26862666 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
@@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
46594639 }
46604640 }
46614641
4662- id <MTLComputePipelineState > pipeline;
4663-
4664- switch (n_fuse) {
4665- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline ; break ;
4666- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline ; break ;
4667- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline ; break ;
4668- default : GGML_ABORT (" unsupported n_fuse = %d \n " , n_fuse);
4669- }
4642+ const id <MTLComputePipelineState > pipeline = ggml_metal_get_pipeline_rms_norm (backend, node, n_fuse);
46704643
46714644 int nth = 32 ; // SIMD width
46724645
0 commit comments