@@ -226,6 +226,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
226226 GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
227227 GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
228228 GGML_METAL_KERNEL_TYPE_RMS_NORM,
229+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
230+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
229231 GGML_METAL_KERNEL_TYPE_L2_NORM,
230232 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
231233 GGML_METAL_KERNEL_TYPE_NORM,
@@ -1222,6 +1224,8 @@ @implementation GGMLMetalClass
12221224 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
12231225 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
12241226 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1227+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1228+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
12251229 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12261230 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12271231 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -2115,6 +2119,10 @@ static int ggml_metal_encode_node(
21152119 /* .o1 =*/ { offs_src1 },
21162120 };
21172121
2122+ // c[0] = add(a, b[0])
2123+ // c[1] = add(c[0], b[1])
2124+ // c[2] = add(c[1], b[2])
2125+ // ...
21182126 {
21192127 ops[0 ] = GGML_OP_ADD;
21202128 ops[1 ] = GGML_OP_ADD;
@@ -2133,6 +2141,11 @@ static int ggml_metal_encode_node(
21332141 break ;
21342142 }
21352143
2144+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
2145+ break ;
2146+ }
2147+
2148+ // b[0] === b[1] === ...
21362149 if (!ggml_are_same_layout (nodes[n_fuse]->src [1 ], nodes[n_fuse + 1 ]->src [1 ])) {
21372150 break ;
21382151 }
@@ -4123,12 +4136,86 @@ static int ggml_metal_encode_node(
41234136 case GGML_OP_RMS_NORM:
41244137 {
41254138 GGML_ASSERT (ne00 % 4 == 0 );
4126- GGML_ASSERT (ggml_is_contiguous_1 (src0));
4139+ GGML_ASSERT (ggml_is_contiguous_rows (src0));
41274140
41284141 float eps;
41294142 memcpy (&eps, dst->op_params , sizeof (float ));
41304143
4131- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline ;
4144+ ggml_metal_kargs_rms_norm args = {
4145+ /* .ne00 =*/ ne00,
4146+ /* .ne00_4 =*/ ne00/4 ,
4147+ /* .nb1 =*/ nb1,
4148+ /* .nb2 =*/ nb2,
4149+ /* .nb3 =*/ nb3,
4150+ /* .eps =*/ eps,
4151+ /* .nef1 =*/ { ne01 },
4152+ /* .nef2 =*/ { ne02 },
4153+ /* .nef3 =*/ { ne03 },
4154+ /* .nbf1 =*/ { nb01 },
4155+ /* .nbf2 =*/ { nb02 },
4156+ /* .nbf3 =*/ { nb03 },
4157+ };
4158+
4159+ size_t offs_fuse[2 ] = { 0 , 0 };
4160+ id <MTLBuffer > id_fuse[2 ] = { id_src0, id_src0 };
4161+
4162+ // d[0] = rms_norm(a)
4163+ // d[1] = mul(d[0], b)
4164+ // d[2] = add(d[1], c)
4165+ {
4166+ ops[0 ] = GGML_OP_RMS_NORM;
4167+ ops[1 ] = GGML_OP_MUL;
4168+ ops[2 ] = GGML_OP_ADD;
4169+
4170+ for (n_fuse = 0 ; n_fuse <= 1 ; ++n_fuse) {
4171+ if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
4172+ break ;
4173+ }
4174+
4175+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
4176+ break ;
4177+ }
4178+
4179+ if (nodes[n_fuse + 1 ]->src [1 ]->ne [0 ] != node->ne [0 ]) {
4180+ break ;
4181+ }
4182+
4183+ if (!ggml_is_contiguous_rows (nodes[n_fuse + 1 ]->src [1 ])) {
4184+ break ;
4185+ }
4186+
4187+ if (nodes[n_fuse + 1 ]->type != GGML_TYPE_F32) {
4188+ break ;
4189+ }
4190+
4191+ id_fuse[n_fuse] = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse[n_fuse]);
4192+
4193+ args.nef1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [1 ];
4194+ args.nef2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [2 ];
4195+ args.nef3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [3 ];
4196+
4197+ args.nbf1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [1 ];
4198+ args.nbf2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [2 ];
4199+ args.nbf3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [3 ];
4200+ }
4201+
4202+ ++n_fuse;
4203+ }
4204+
4205+ // GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
4206+
4207+ if (n_fuse > 1 ) {
4208+ id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
4209+ }
4210+
4211+ id <MTLComputePipelineState > pipeline;
4212+
4213+ switch (n_fuse) {
4214+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline ; break ;
4215+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline ; break ;
4216+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline ; break ;
4217+ default : GGML_ABORT (" unsupported n_fuse = %d \n " , n_fuse);
4218+ }
41324219
41334220 int nth = 32 ; // SIMD width
41344221
@@ -4139,23 +4226,16 @@ static int ggml_metal_encode_node(
41394226 nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
41404227 nth = MIN (nth, ne00/4 );
41414228
4142- ggml_metal_kargs_rms_norm args = {
4143- /* .ne00 =*/ ne00,
4144- /* .ne00_4 =*/ ne00/4 ,
4145- /* .nb01 =*/ nb01,
4146- /* .eps =*/ eps,
4147- };
4148-
41494229 [encoder setComputePipelineState: pipeline];
4150- [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4151- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4152- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
4230+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4231+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4232+ [encoder setBuffer: id_fuse[0 ] offset: offs_fuse[0 ] atIndex: 2 ];
4233+ [encoder setBuffer: id_fuse[1 ] offset: offs_fuse[1 ] atIndex: 3 ];
4234+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
41534235
41544236 [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
41554237
4156- const int64_t nrows = ggml_nrows (src0);
4157-
4158- [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4238+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
41594239 } break ;
41604240 case GGML_OP_L2_NORM:
41614241 {
0 commit comments