@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
10901090    return  res;
10911091}
10921092
1093- ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t  lib, const  ggml_tensor * op, int32_t  n_fuse) {
1094-     assert (op->op  == GGML_OP_RMS_NORM);
1095- 
1096-     GGML_ASSERT (op->src [0 ]->ne [0 ] % 4  == 0 );
1097-     GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
1098- 
1099-     char  base[256 ];
1100-     char  name[256 ];
1101- 
1102-     switch  (n_fuse) {
1103-         case  1 : snprintf (base, 256 , " kernel_rms_norm_f32" break ;
1104-         case  2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32" break ;
1105-         case  3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32" break ;
1106-         default : GGML_ABORT (" fatal error" 
1107-     }
1108- 
1109-     snprintf (name, 256 , " %s" 
1110- 
1111-     ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
1112-     if  (res) {
1113-         return  res;
1114-     }
1115- 
1116-     res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
1117- 
1118-     ggml_metal_pipeline_set_smem (res, 32 *sizeof (float ));
1119- 
1120-     return  res;
1121- }
1122- 
11231093ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t  lib, const  ggml_tensor * op) {
11241094    assert (op->op  == GGML_OP_L2_NORM);
11251095
@@ -1167,16 +1137,33 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
11671137    return  res;
11681138}
11691139
1170- ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_norm (ggml_metal_library_t  lib, const  ggml_tensor * op) {
1171-     assert (op->op  == GGML_OP_NORM);
1140+ ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_norm (ggml_metal_library_t  lib, const  ggml_tensor * op,  int  n_fuse ) {
1141+     assert (op->op  == GGML_OP_NORM || op-> op  == GGML_OP_RMS_NORM );
11721142
11731143    GGML_ASSERT (op->src [0 ]->ne [0 ] % 4  == 0 );
1174-     GGML_ASSERT (ggml_is_contiguous_1 (op->src [0 ]));
1144+     GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
11751145
11761146    char  base[256 ];
11771147    char  name[256 ];
11781148
1179-     snprintf (base, 256 , " kernel_norm_f32" 
1149+     switch  (op->op ) {
1150+         case  GGML_OP_NORM:
1151+             switch  (n_fuse) {
1152+                 case  1 : snprintf (base, 256 , " kernel_norm_f32" break ;
1153+                 case  2 : snprintf (base, 256 , " kernel_norm_mul_f32" break ;
1154+                 case  3 : snprintf (base, 256 , " kernel_norm_mul_add_f32" break ;
1155+                 default : GGML_ABORT (" fatal error" 
1156+             } break ;
1157+         case  GGML_OP_RMS_NORM:
1158+             switch  (n_fuse) {
1159+                 case  1 : snprintf (base, 256 , " kernel_rms_norm_f32" break ;
1160+                 case  2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32" break ;
1161+                 case  3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32" break ;
1162+                 default : GGML_ABORT (" fatal error" 
1163+             } break ;
1164+         default : GGML_ABORT (" fatal error" 
1165+     }
1166+ 
11801167    snprintf (name, 256 , " %s" 
11811168
11821169    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
0 commit comments