@@ -2597,91 +2597,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
// Explicit instantiation of kernel_flash_attn_ext_vec_f16 for template argument 128.
// The host_name literal is the exact string under which the kernel is exported, so it
// is written without surrounding whitespace - the same form used by the disabled
// <256> instantiation on the next line.
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
25992599
// Copy kernel: half source -> half destination.
//
// Each threadgroup handles the source row addressed by its grid position
// (x -> dim 1, y -> dim 2, z -> dim 3). The threads of the group walk the ne00
// elements of that row, starting at thread_position_in_threadgroup.x and
// advancing by threads_per_threadgroup.x.
//
// The first element of the row is flattened to a linear index using the source
// extents (ne00, ne01, ne02) and expanded again into destination coordinates
// using the destination extents (ne0, ne1, ne2). Source elements are addressed
// through the byte strides nb00..nb03; destination elements are stored as
// consecutive half values starting at the byte offset built from nb0..nb3.
//
// ne03 and ne3 are part of the signature but are not read here.
kernel void kernel_cpy_f16_f16(
        device  const half * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates are taken from the threadgroup's grid position
    const int64_t r1 = tgpig.x;
    const int64_t r2 = tgpig.y;
    const int64_t r3 = tgpig.z;

    // linear index of the first element of this source row
    const int64_t lin = ((r3*ne02 + r2)*ne01 + r1)*ne00;

    // element counts spanned by one step along destination dims 1, 2 and 3
    const int64_t per_d1 = ne0;
    const int64_t per_d2 = ne1*ne0;
    const int64_t per_d3 = ne2*ne1*ne0;

    // peel the linear index into destination coordinates, outermost first
    int64_t rest = lin;
    const int64_t d3 = rest/per_d3; rest -= d3*per_d3;
    const int64_t d2 = rest/per_d2; rest -= d2*per_d2;
    const int64_t d1 = rest/per_d1; rest -= d1*per_d1;
    const int64_t d0 = rest;

    device half * out = (device half *)((device char *) dst + (d3*nb3 + d2*nb2 + d1*nb1 + d0*nb0));

    // byte address of the source row; only the nb00 term changes per element
    device const char * row = (device const char *) src0 + (r3*nb03 + r2*nb02 + r1*nb01);

    for (int64_t k = tpitg.x; k < ne00; k += ntg.x) {
        out[k] = *((device const half *)(row + k*nb00));
    }
}
2640-
// Copy kernel: half source -> float destination (each value is converted on store).
//
// Each threadgroup handles the source row addressed by its grid position
// (x -> dim 1, y -> dim 2, z -> dim 3). The threads of the group walk the ne00
// elements of that row, starting at thread_position_in_threadgroup.x and
// advancing by threads_per_threadgroup.x.
//
// The first element of the row is flattened to a linear index using the source
// extents (ne00, ne01, ne02) and expanded again into destination coordinates
// using the destination extents (ne0, ne1, ne2). Source elements are addressed
// through the byte strides nb00..nb03; destination elements are stored as
// consecutive float values starting at the byte offset built from nb0..nb3.
//
// ne03 and ne3 are part of the signature but are not read here.
kernel void kernel_cpy_f16_f32(
        device  const half * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates are taken from the threadgroup's grid position
    const int64_t r1 = tgpig.x;
    const int64_t r2 = tgpig.y;
    const int64_t r3 = tgpig.z;

    // linear index of the first element of this source row
    const int64_t lin = ((r3*ne02 + r2)*ne01 + r1)*ne00;

    // element counts spanned by one step along destination dims 1, 2 and 3
    const int64_t per_d1 = ne0;
    const int64_t per_d2 = ne1*ne0;
    const int64_t per_d3 = ne2*ne1*ne0;

    // peel the linear index into destination coordinates, outermost first
    int64_t rest = lin;
    const int64_t d3 = rest/per_d3; rest -= d3*per_d3;
    const int64_t d2 = rest/per_d2; rest -= d2*per_d2;
    const int64_t d1 = rest/per_d1; rest -= d1*per_d1;
    const int64_t d0 = rest;

    device float * out = (device float *)((device char *) dst + (d3*nb3 + d2*nb2 + d1*nb1 + d0*nb0));

    // byte address of the source row; only the nb00 term changes per element
    device const char * row = (device const char *) src0 + (r3*nb03 + r2*nb02 + r1*nb01);

    for (int64_t k = tpitg.x; k < ne00; k += ntg.x) {
        out[k] = *((device const half *)(row + k*nb00));
    }
}
2681-
2682- kernel void kernel_cpy_f32_f16 (
2683- device const float * src0,
2684- device half * dst,
2600+ template <typename T0, typename T1>
2601+ kernel void kernel_cpy (
2602+ device const void * src0,
2603+ device void * dst,
26852604 constant int64_t & ne00,
26862605 constant int64_t & ne01,
26872606 constant int64_t & ne02,
@@ -2712,56 +2631,22 @@ kernel void kernel_cpy_f32_f16(
27122631 const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
27132632 const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
27142633
2715- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
27162635
27172636 for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
2718- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2719-
2720- dst_data[i00] = src[0 ];
2637+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2638+ dst_data[i00] = (T1) src[0 ];
27212639 }
27222640}
27232641
// kernel_cpy_f32_f32: float source -> float destination copy kernel.
// Lines marked `-` belong to this kernel; the single `+` line inside the span is a
// separate typedef that the diff view interleaved with the kernel body.
//
// The threadgroup's grid position selects the source row (i01, i02, i03); the row's
// first element is flattened to the linear index n with the source extents and then
// expanded into destination coordinates (i0..i3) with the destination extents.
// nb00..nb03 and nb0..nb3 are applied as byte offsets on char pointers.
// ne03 and ne3 are declared but not read.
2724- kernel void kernel_cpy_f32_f32 (
2725- device const float * src0,
2726- device float * dst,
2727- constant int64_t & ne00,
2728- constant int64_t & ne01,
2729- constant int64_t & ne02,
2730- constant int64_t & ne03,
2731- constant uint64_t & nb00,
2732- constant uint64_t & nb01,
2733- constant uint64_t & nb02,
2734- constant uint64_t & nb03,
2735- constant int64_t & ne0,
2736- constant int64_t & ne1,
2737- constant int64_t & ne2,
2738- constant int64_t & ne3,
2739- constant uint64_t & nb0,
2740- constant uint64_t & nb1,
2741- constant uint64_t & nb2,
2742- constant uint64_t & nb3,
2743- uint3 tgpig[[threadgroup_position_in_grid]],
2744- uint3 tpitg[[thread_position_in_threadgroup]],
2745- uint3 ntg[[threads_per_threadgroup]]) {
2746- const int64_t i03 = tgpig[2 ];
2747- const int64_t i02 = tgpig[1 ];
2748- const int64_t i01 = tgpig[0 ];
2749-
// n: linear index of the first element of the selected source row
2750- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2751-
// split n into destination coordinates, outermost dimension first
2752- const int64_t i3 = n / (ne2*ne1*ne0);
2753- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2754- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2755- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2756-
2757- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
// Added-side line: alias for the function type of kernel_cpy<float, float>, used as the
// declared type of the explicit kernel_cpy instantiations that follow this span.
2642+ typedef decltype (kernel_cpy<float , float >) kernel_cpy_t;
27582643
// threads of the group stride over the ne00 elements of the row by ntg.x;
// the destination is indexed as consecutive float elements from dst_data
2759- for ( int64_t i00 = tpitg. x ; i00 < ne00; i00 += ntg. x ) {
2760- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00) ;
2761-
2762- dst_data[i00] = src[ 0 ] ;
2763- }
2764- }
// Explicit instantiations of kernel_cpy<T0, T1>, one per (source type, destination type) pair.
// Each host_name literal is the exact string under which the kernel is exported, so it must
// not carry leading or trailing whitespace. The f32/f16 names are identical to the identifiers
// of the non-template kernels these instantiations replace (kernel_cpy_f32_f32,
// kernel_cpy_f32_f16, kernel_cpy_f16_f16, kernel_cpy_f16_f32).
template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy<float,  float>;
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy<float,  half>;
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy<half,   half>;
template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy<half,   float>;
27652650
27662651kernel void kernel_cpy_f32_q8_0 (
27672652 device const float * src0,
0 commit comments