@@ -154,20 +154,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
154154 GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155155 GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156156 GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
157- GGML_METAL_KERNEL_TYPE_ADD_ROW ,
158- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ,
159- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ,
160- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ,
161- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ,
162- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ,
163- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ,
164- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ,
157+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ,
158+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ,
159+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ,
160+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ,
161+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ,
162+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ,
163+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ,
164+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ,
165165 GGML_METAL_KERNEL_TYPE_SUB,
166- GGML_METAL_KERNEL_TYPE_SUB_ROW ,
166+ GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ,
167167 GGML_METAL_KERNEL_TYPE_MUL,
168- GGML_METAL_KERNEL_TYPE_MUL_ROW ,
168+ GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ,
169169 GGML_METAL_KERNEL_TYPE_DIV,
170- GGML_METAL_KERNEL_TYPE_DIV_ROW ,
170+ GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ,
171171 GGML_METAL_KERNEL_TYPE_REPEAT_F32,
172172 GGML_METAL_KERNEL_TYPE_REPEAT_F16,
173173 GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -1150,20 +1150,20 @@ @implementation GGMLMetalClass
11501150 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
11511151 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
11521152 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1153- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW , add_row, true );
1154- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 , add_row_fuse_2, true );
1155- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 , add_row_fuse_3, true );
1156- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 , add_row_fuse_4, true );
1157- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 , add_row_fuse_5, true );
1158- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 , add_row_fuse_6, true );
1159- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 , add_row_fuse_7, true );
1160- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 , add_row_fuse_8, true );
1153+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 , add_row_c4, true );
1154+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 , add_row_c4_fuse_2, true );
1155+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 , add_row_c4_fuse_3, true );
1156+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 , add_row_c4_fuse_4, true );
1157+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 , add_row_c4_fuse_5, true );
1158+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 , add_row_c4_fuse_6, true );
1159+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 , add_row_c4_fuse_7, true );
1160+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 , add_row_c4_fuse_8, true );
11611161 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1162- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW , sub_row, true );
1162+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 , sub_row_c4, true );
11631163 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
1164- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW , mul_row, true );
1164+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 , mul_row_c4, true );
11651165 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
1166- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW , div_row, true );
1166+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 , div_row_c4, true );
11671167 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
11681168 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
11691169 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
@@ -2149,6 +2149,8 @@ static int ggml_metal_encode_node(
21492149 ++n_fuse;
21502150 }
21512151
2152+ // GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
2153+
21522154 if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
21532155 GGML_ASSERT (ggml_is_contiguous (src0));
21542156
@@ -2159,20 +2161,20 @@ static int ggml_metal_encode_node(
21592161 case GGML_OP_ADD:
21602162 {
21612163 switch (n_fuse) {
2162- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline ; break ;
2163- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ].pipeline ; break ;
2164- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ].pipeline ; break ;
2165- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ].pipeline ; break ;
2166- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ].pipeline ; break ;
2167- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ].pipeline ; break ;
2168- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ].pipeline ; break ;
2169- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ].pipeline ; break ;
2164+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline ; break ;
2165+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ].pipeline ; break ;
2166+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ].pipeline ; break ;
2167+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ].pipeline ; break ;
2168+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ].pipeline ; break ;
2169+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ].pipeline ; break ;
2170+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ].pipeline ; break ;
2171+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ].pipeline ; break ;
21702172 default : GGML_ABORT (" fatal error" );
21712173 }
21722174 } break ;
2173- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW ].pipeline ; break ;
2174- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW ].pipeline ; break ;
2175- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW ].pipeline ; break ;
2175+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ].pipeline ; break ;
2176+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ].pipeline ; break ;
2177+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ].pipeline ; break ;
21762178 default : GGML_ABORT (" fatal error" );
21772179 }
21782180
@@ -2207,11 +2209,7 @@ static int ggml_metal_encode_node(
22072209 [encoder setComputePipelineState: pipeline];
22082210 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
22092211 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2210- if (dst->op == GGML_OP_ADD) {
2211- [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2212- } else {
2213- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2214- }
2212+ [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
22152213 [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
22162214
22172215 if (bcast_row) {
0 commit comments