@@ -232,28 +232,6 @@ - (void) dealloc {
232
232
@end
233
233
234
234
enum ggml_metal_kernel_type {
235
- GGML_METAL_KERNEL_TYPE_ADD,
236
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
237
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
238
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
239
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
240
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
241
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
242
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
243
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
244
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
245
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
246
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
247
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
248
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
249
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
250
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
251
- GGML_METAL_KERNEL_TYPE_SUB,
252
- GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
253
- GGML_METAL_KERNEL_TYPE_MUL,
254
- GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
255
- GGML_METAL_KERNEL_TYPE_DIV,
256
- GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
257
235
GGML_METAL_KERNEL_TYPE_ADD_ID,
258
236
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
259
237
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
@@ -319,9 +297,6 @@ - (void) dealloc {
319
297
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
320
298
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
321
299
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
322
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
323
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
324
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
325
300
GGML_METAL_KERNEL_TYPE_L2_NORM,
326
301
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
327
302
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1177,28 +1152,6 @@ @implementation GGMLMetalClass
1177
1152
1178
1153
// simd_sum and simd_max requires MTLGPUFamilyApple7
1179
1154
1180
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1181
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1182
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true );
1183
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1184
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true );
1185
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1186
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1187
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1188
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true );
1189
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true );
1190
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true );
1191
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true );
1192
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true );
1193
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true );
1194
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true );
1195
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true );
1196
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1197
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true );
1198
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
1199
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true );
1200
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
1201
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true );
1202
1155
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true );
1203
1156
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
1204
1157
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
@@ -1264,9 +1217,6 @@ @implementation GGMLMetalClass
1264
1217
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
1265
1218
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1266
1219
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1267
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1268
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1269
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1270
1220
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1271
1221
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1272
1222
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -1722,6 +1672,73 @@ @implementation GGMLMetalClass
1722
1672
GGML_UNUSED (op);
1723
1673
}
1724
1674
1675
// Fetch the compute pipeline for an element-wise binary op (add / sub / mul / div).
//
//   backend : backend->context is used as the Metal context for the kernel lookup
//   op      : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL or GGML_OP_DIV; any other value hits GGML_ABORT
//   n_fuse  : integer formatted into the kernel name after "_fuse_"
//             (presumably the number of fused ops - confirm against the callers)
//   row     : when true, the "_row_c4" variant of the kernel name is built
//
// Returns the result of ggml_metal_get_kernel() when it is non-nil; otherwise returns
// whatever ggml_metal_compile_kernel() yields for the same base/name pair.
//
// NOTE(review): the string literals below contain extra spaces in this rendering (e.g. " add"),
// which look like extraction artifacts - verify the exact kernel names against the upstream file.
+ static id <MTLComputePipelineState > ggml_metal_get_pipeline_bin (
1676
+ ggml_backend_t backend, enum ggml_op op,
1677
+ int32_t n_fuse,
1678
+ bool row) {
1679
+ struct ggml_backend_metal_context * ctx = backend->context ;
1680
+
1681
// fixed-size scratch buffers for the kernel base name and the lookup name
+ char base[256 ];
1682
+ char name[256 ];
1683
+
1684
+ @autoreleasepool {
1685
// placeholder value; each handled case in the switch below overwrites it
+ const char * op_str = " undefined" ;
1686
+ switch (op) {
1687
+ case GGML_OP_ADD: op_str = " add" ; break ;
1688
+ case GGML_OP_SUB: op_str = " sub" ; break ;
1689
+ case GGML_OP_MUL: op_str = " mul" ; break ;
1690
+ case GGML_OP_DIV: op_str = " div" ; break ;
1691
+ default : GGML_ABORT (" fatal error" );
1692
+ };
1693
+
1694
// compose the base kernel name from the op string and n_fuse; `row` selects the row_c4 form
+ if (row) {
1695
+ snprintf (base, 256 , " kernel_%s _row_c4_fuse_%d " , op_str, n_fuse);
1696
+ } else {
1697
+ snprintf (base, 256 , " kernel_%s _fuse_%d " , op_str, n_fuse);
1698
+ }
1699
+
1700
// the lookup name is a plain copy of base here - nothing else is appended
+ snprintf (name, 256 , " %s " , base);
1701
+
1702
// try the existing-kernel lookup first
+ id <MTLComputePipelineState > res = ggml_metal_get_kernel (ctx, name);
1703
+ if (res) {
1704
+ // kernel found
1705
+ return res;
1706
+ }
1707
+
1708
// lookup returned nil: fall back to ggml_metal_compile_kernel (last argument is nil)
+ return ggml_metal_compile_kernel (backend, base, name, nil );
1709
+ }
1710
+ }
1711
+
1712
// Fetch the compute pipeline for RMS norm, choosing the kernel name from n_fuse.
//
//   backend : backend->context is used as the Metal context for the kernel lookup
//   op      : not read by this function (only passed to GGML_UNUSED)
//   n_fuse  : 1 -> rms_norm, 2 -> rms_norm_mul, 3 -> rms_norm_mul_add; any other value hits GGML_ABORT
//
// Returns the result of ggml_metal_get_kernel() when it is non-nil; otherwise returns
// whatever ggml_metal_compile_kernel() yields for the same base/name pair.
//
// NOTE(review): the string literals below contain extra spaces in this rendering,
// which look like extraction artifacts - verify the exact kernel names against the upstream file.
+ static id <MTLComputePipelineState > ggml_metal_get_pipeline_rms_norm (
1713
+ ggml_backend_t backend, struct ggml_tensor * op,
1714
+ int32_t n_fuse) {
1715
+ struct ggml_backend_metal_context * ctx = backend->context ;
1716
+
1717
// fixed-size scratch buffers for the kernel base name and the lookup name
+ char base[256 ];
1718
+ char name[256 ];
1719
+
1720
+ @autoreleasepool {
1721
// pick the base kernel name for the requested n_fuse value
+ switch (n_fuse) {
1722
+ case 1 : snprintf (base, 256 , " kernel_rms_norm" ); break ;
1723
+ case 2 : snprintf (base, 256 , " kernel_rms_norm_mul" ); break ;
1724
+ case 3 : snprintf (base, 256 , " kernel_rms_norm_mul_add" ); break ;
1725
+ default : GGML_ABORT (" fatal error" );
1726
+ }
1727
+
1728
// the lookup name is a plain copy of base here - nothing else is appended
+ snprintf (name, 256 , " %s " , base);
1729
+
1730
// try the existing-kernel lookup first
+ id <MTLComputePipelineState > res = ggml_metal_get_kernel (ctx, name);
1731
+ if (res) {
1732
+ // kernel found
1733
+ return res;
1734
+ }
1735
+
1736
// lookup returned nil: fall back to ggml_metal_compile_kernel (last argument is nil)
+ return ggml_metal_compile_kernel (backend, base, name, nil );
1737
+ }
1738
+
1739
// NOTE(review): every path through the block above returns or aborts, so this line is
// presumably only here to mark the `op` parameter as intentionally unused - confirm
+ GGML_UNUSED (op);
1740
+ }
1741
+
1725
1742
static void ggml_metal_free (struct ggml_backend_metal_context * ctx) {
1726
1743
GGML_LOG_INFO (" %s : deallocating\n " , __func__);
1727
1744
@@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
2359
2376
2360
2377
bool bcast_row = false ;
2361
2378
2362
- id <MTLComputePipelineState > pipeline = nil ;
2363
-
2364
2379
ggml_metal_kargs_bin args = {
2365
2380
/* .ne00 =*/ ne00,
2366
2381
/* .ne01 =*/ ne01,
@@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
2441
2456
}
2442
2457
}
2443
2458
2459
+ id <MTLComputePipelineState > pipeline = nil ;
2460
+
2444
2461
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2445
2462
GGML_ASSERT (ggml_is_contiguous (src0));
2446
2463
2447
2464
// src1 is a row
2448
2465
GGML_ASSERT (ne11 == 1 );
2449
2466
2450
- switch (dst->op ) {
2451
- case GGML_OP_ADD:
2452
- {
2453
- switch (n_fuse) {
2454
- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline ; break ;
2455
- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline ; break ;
2456
- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline ; break ;
2457
- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline ; break ;
2458
- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline ; break ;
2459
- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline ; break ;
2460
- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline ; break ;
2461
- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline ; break ;
2462
- default : GGML_ABORT (" fatal error" );
2463
- }
2464
- } break ;
2465
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline ; break ;
2466
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline ; break ;
2467
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline ; break ;
2468
- default : GGML_ABORT (" fatal error" );
2469
- }
2467
+ pipeline = ggml_metal_get_pipeline_bin (backend, dst->op , n_fuse, true );
2470
2468
2471
2469
bcast_row = true ;
2472
2470
} else {
2473
- switch (dst->op ) {
2474
- case GGML_OP_ADD:
2475
- {
2476
- switch (n_fuse) {
2477
- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD ].pipeline ; break ;
2478
- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2479
- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline ; break ;
2480
- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2481
- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline ; break ;
2482
- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2483
- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline ; break ;
2484
- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2485
- default : GGML_ABORT (" fatal error" );
2486
- }
2487
- } break ;
2488
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2489
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2490
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2491
- default : GGML_ABORT (" fatal error" );
2492
- }
2471
+ pipeline = ggml_metal_get_pipeline_bin (backend, dst->op , n_fuse, false );
2493
2472
}
2494
2473
2495
2474
if (n_fuse > 1 ) {
@@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
2650
2629
ggml_metal_encode_concurrency_reset (ctx_enc);
2651
2630
}
2652
2631
2653
- const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
2654
-
2655
2632
ggml_metal_kargs_bin args = {
2656
2633
/* .ne00 =*/ ne00,
2657
2634
/* .ne01 =*/ ne01,
@@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
2681
2658
/* .o1 =*/ { offs_src1},
2682
2659
};
2683
2660
2661
+ // const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
2662
+ const id <MTLComputePipelineState > pipeline = ggml_metal_get_pipeline_bin (backend, GGML_OP_ADD, 1 , false );
2663
+
2684
2664
[encoder setComputePipelineState: pipeline];
2685
2665
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2686
2666
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
@@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
4659
4639
}
4660
4640
}
4661
4641
4662
- id <MTLComputePipelineState > pipeline;
4663
-
4664
- switch (n_fuse) {
4665
- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline ; break ;
4666
- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline ; break ;
4667
- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline ; break ;
4668
- default : GGML_ABORT (" unsupported n_fuse = %d \n " , n_fuse);
4669
- }
4642
+ const id <MTLComputePipelineState > pipeline = ggml_metal_get_pipeline_rms_norm (backend, node, n_fuse);
4670
4643
4671
4644
int nth = 32 ; // SIMD width
4672
4645
0 commit comments