@@ -2723,45 +2723,9 @@ kernel void kernel_leaky_relu_f32(
27232723 dst[tpig] = src0[tpig] > 0 .0f ? src0[tpig] : src0[tpig] * slope;
27242724}
27252725
2726- typedef void (flash_attn_ext_t )(
2727- device const char * q,
2728- device const char * k,
2729- device const char * v,
2730- device const char * mask,
2731- device float * dst,
2732- constant int64_t & ne01,
2733- constant int64_t & ne02,
2734- constant int64_t & ne03,
2735- constant uint64_t & nb01,
2736- constant uint64_t & nb02,
2737- constant uint64_t & nb03,
2738- constant int64_t & ne11,
2739- constant int64_t & ne12,
2740- constant int64_t & ne13,
2741- constant uint64_t & nb11,
2742- constant uint64_t & nb12,
2743- constant uint64_t & nb13,
2744- constant uint64_t & nb21,
2745- constant uint64_t & nb22,
2746- constant uint64_t & nb23,
2747- constant uint64_t & nb31,
2748- constant int64_t & ne1,
2749- constant int64_t & ne2,
2750- constant float & scale,
2751- constant float & max_bias,
2752- constant float & m0,
2753- constant float & m1,
2754- constant uint32_t & n_head_log2,
2755- constant float & logit_softcap,
2756- threadgroup half * shared,
2757- uint3 tgpig[[threadgroup_position_in_grid]],
2758- uint3 tpitg[[thread_position_in_threadgroup]],
2759- uint3 ntg[[threads_per_threadgroup]],
2760- ushort tiisg[[thread_index_in_simdgroup]],
2761- ushort sgitg[[simdgroup_index_in_threadgroup]]);
2762-
27632726// ref: https://arxiv.org/pdf/2307.08691.pdf
2764- template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &), short D, short Q = 8 , short K = 8 , short C = 32 > // head size, queries per threadgroup, cache items per threadgroup
2727+ // D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
2728+ template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &), short D, short Q = 8 , short KV = 8 , short C = 32 >
27652729kernel void kernel_flash_attn_ext (
27662730 device const char * q,
27672731 device const char * k,
@@ -2818,8 +2782,8 @@ kernel void kernel_flash_attn_ext(
28182782 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
28192783 threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
28202784
2821- threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4 *16 *K ) + Q*T); // scratch buffer to load K and V in shared memory
2822- threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4 *16 *K ) + Q*T); // same as above but in half4x4
2785+ threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4 *16 *KV ) + Q*T); // scratch buffer to load K and V in shared memory
2786+ threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4 *16 *KV ) + Q*T); // same as above but in half4x4
28232787
28242788 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
28252789 simdgroup_half8x8 lo[D8];
@@ -3179,6 +3143,8 @@ kernel void kernel_flash_attn_ext(
31793143 }
31803144}
31813145
3146+ typedef decltype (kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_t;
3147+
31823148template [[host_name(" kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 64 >;
31833149template [[host_name(" kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 80 >;
31843150template [[host_name(" kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 96 >;
@@ -3223,7 +3189,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_
32233189
32243190// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
32253191template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &), short D, short Q = 1 , short C = 32 >
3226- kernel void flash_attn_ext_vec (
3192+ kernel void kernel_flash_attn_ext_vec (
32273193 device const char * q,
32283194 device const char * k,
32293195 device const char * v,
@@ -3548,22 +3514,21 @@ kernel void flash_attn_ext_vec(
35483514 }
35493515}
35503516
3551- // template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>;
3552- // template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>;
3517+ typedef decltype (kernel_flash_attn_ext_vec<half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
35533518
3554- template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 128 >;
3555- template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 128 >;
3556- template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 128 >;
3557- template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 128 >;
3558- template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 128 >;
3559- template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 128 >;
3519+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 128 >;
3520+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 128 >;
3521+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 128 >;
3522+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 128 >;
3523+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 128 >;
3524+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 128 >;
35603525
3561- template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 256 >;
3562- template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 256 >;
3563- template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 256 >;
3564- template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 256 >;
3565- template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 256 >;
3566- template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 256 >;
3526+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 256 >;
3527+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 256 >;
3528+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 256 >;
3529+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 256 >;
3530+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 256 >;
3531+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 256 >;
35673532
35683533template <typename T0, typename T1>
35693534kernel void kernel_cpy (
0 commit comments