193193 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
194194 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195195 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196- GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
197196 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
197+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
198+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
198200 GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
199201 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
200202 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
201203 GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
202204 GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
203205 GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
204- GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
205- GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
206206 GGML_METAL_KERNEL_TYPE_CONCAT,
207207 GGML_METAL_KERNEL_TYPE_SQR,
208208 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
651651 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
652652 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
653653 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
654+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
655+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
654656 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
655657 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
656658 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
657659 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true );
658660 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true );
659661 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true );
660- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
661- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
662662 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT, concat, true );
663663 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR, sqr, true );
664664 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
810810 switch (op->src [0 ]->type ) {
811811 case GGML_TYPE_F32:
812812 switch (op->type ) {
813- case GGML_TYPE_F16:
814813 case GGML_TYPE_F32:
814+ case GGML_TYPE_F16:
815815 case GGML_TYPE_Q8_0:
816816 case GGML_TYPE_Q4_0:
817817 case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
824824 }
825825 case GGML_TYPE_F16:
826826 switch (op->type ) {
827- case GGML_TYPE_F16:
828827 case GGML_TYPE_F32:
828+ case GGML_TYPE_F16:
829829 return true ;
830830 default :
831831 return false ;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
837837 case GGML_OP_DIAG_MASK_INF:
838838 case GGML_OP_GET_ROWS:
839839 {
840- return op->src [ 0 ]-> type != GGML_TYPE_BF16 && op-> ne [3 ] == 1 ;
840+ return op->ne [3 ] == 1 ;
841841 }
842842 default :
843843 return false ;
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
15801580 // some Metal matrix data types require aligned pointers
15811581 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15821582 switch (src0->type ) {
1583- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1584- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1583+ case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1584+ case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
15851585 default : break ;
15861586 }
15871587
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
27752775 GGML_ASSERT (ne0 % ggml_blck_size (dst->type ) == 0 );
27762776
27772777 switch (dstt) {
2778- case GGML_TYPE_F16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16 ].pipeline ; break ;
2779- case GGML_TYPE_F32 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32 ].pipeline ; break ;
2778+ case GGML_TYPE_F32 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32 ].pipeline ; break ;
2779+ case GGML_TYPE_F16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16 ].pipeline ; break ;
27802780 case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
27812781 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
27822782 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline ; break ;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
27892789 case GGML_TYPE_F16:
27902790 {
27912791 switch (dstt) {
2792- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F16 ].pipeline ; break ;
2793- case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F32 ].pipeline ; break ;
2792+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F32 ].pipeline ; break ;
2793+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F16 ].pipeline ; break ;
27942794 default : GGML_ASSERT (false && " not implemented" );
27952795 };
27962796 } break ;
0 commit comments