diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a23da57e3a1dc..578e3477f6d25 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -216,6 +216,10 @@ typedef float2 dfloat2; #define FP16_AVAILABLE #endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define BF16_AVAILABLE +#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 @@ -927,3 +931,7 @@ struct ggml_backend_cuda_context { return pool(device); } }; + +static __device__ __forceinline__ __nv_bfloat16 ggml_cuda_bf16max(const __nv_bfloat16 a, const __nv_bfloat16 b) { + return __float2bfloat16(fmaxf((float)a, (float)b)); +} diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index e3beddbc1b23b..522615ad085ca 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -649,10 +649,52 @@ static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { switch (type) { - case GGML_TYPE_F32: - return convert_unary_cont_cuda; + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cont_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_cuda; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_cuda; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_cuda; + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ1_M: + return dequantize_row_iq1_m_cuda; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ3_S: + return dequantize_row_iq3_s_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda; + case GGML_TYPE_F32: + return convert_unary_cont_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e46f0e2081bdf..cb5ed20748f7d 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -5,6 +5,7 @@ #include "vecdotq.cuh" #include +#include #define FATTN_KQ_STRIDE 256 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. 
@@ -35,6 +36,8 @@ typedef void (* fattn_kernel_t)( typedef half (*vec_dot_KQ_f16_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +typedef nv_bfloat16 (*vec_dot_KQ_bf16_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); @@ -238,16 +241,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const int v = get_int_b2(K_q8_0[ib].qs, iqs); - T Q_d; - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); - } else { + if constexpr (std::is_same_v) { const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/warp_size].x; + nv_bfloat16 Q_d = Q_ds[k_KQ_0/warp_size].x; + nv_bfloat16 K_d = __float2bfloat16(__half2float(K_q8_0[ib].d)); + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_d, Q_d); + } else { + T Q_d; + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); + } else { + const float2 * Q_ds = (const float2 *) Q_ds_v; + Q_d = Q_ds[k_KQ_0/warp_size].x; + } + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); } - - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); } return sum; @@ -262,7 +271,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( GGML_UNUSED(Q_ds_v); #ifdef FP16_AVAILABLE - if (std::is_same::value) { + if constexpr (std::is_same_v) { const half2 * Q_h2 = (const half2 *) Q_v; half2 sum2 = make_half2(0.0f, 0.0f); @@ -272,27 +281,63 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; + sum2 = __hadd2(sum2, __hmul2(K_ik, Q_h2[k_KQ_0/warp_size])); } return __low2half(sum2) + __high2half(sum2); - } + } else #endif // FP16_AVAILABLE + { + const nv_bfloat162 * Q_bf16_2 = (const nv_bfloat162 *) Q_v; + float sum = 0.0f; - const float2 * Q_f2 = (const float2 *) Q_v; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { + const int k_KQ = k_KQ_0 + threadIdx.x; - float sum = 0.0f; + const half2 K_ik = K_h2[k_KQ]; + const nv_bfloat162 Q_ik = Q_bf16_2[k_KQ_0/warp_size]; + const float2 Q_f2 = __bfloat1622float2(Q_ik); + + sum += __half2float(__low2half(K_ik)) * Q_f2.x; + sum += __half2float(__high2half(K_ik)) * Q_f2.y; + } + + return (T)sum; + } +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + const nv_bfloat162 * K_bf16_2 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + +#ifdef BF16_AVAILABLE + if (std::is_same::value) { + const nv_bfloat162 * Q_bf16_2 = (const nv_bfloat162 *) Q_v; + nv_bfloat162 sum2 = make_bfloat162(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { + const int k_KQ = k_KQ_0 + threadIdx.x; + sum2 = __hadd2(sum2, __hmul2(K_bf16_2[k_KQ], Q_bf16_2[k_KQ_0/warp_size])); + } + return __low2bfloat16(sum2) + __high2bfloat16(sum2); + } +#endif // BF16_AVAILABLE + const float2 * Q_f2 = (const float2 *) Q_v; + float sum = 0.0f; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 
< D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; + const nv_bfloat162 K_ik = K_bf16_2[k_KQ]; + const float2 K_f2 = __bfloat1622float2(K_ik); + sum += K_f2.x * Q_f2[k_KQ_0/warp_size].x; + sum += K_f2.y * Q_f2[k_KQ_0/warp_size].y; } - - return sum; + return (T)sum; } template @@ -340,6 +385,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } typedef half (*dequantize_1_f16_t)(const void *, const int64_t); +typedef nv_bfloat16 (*dequantize_1_bf16_t)(const void *, const int64_t); typedef float (*dequantize_1_f32_t)(const void *, const int64_t); template @@ -350,17 +396,11 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ const int iqs = i % (QK4_0/2); const int shift = (i % QK4_0) / (QK4_0/2); - const T d = x[ib].d; + const half d = x[ib].d; const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F) - 8; -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE - - return ((float) d)*((float) q); + return (T)(__half2float(d) * (float)q); } template @@ -375,13 +415,8 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F); -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE - - return __low2float(dm)*((float) q) + __high2float(dm); + const float result = __low2float(dm)*((float) q) + __high2float(dm); + return (T)result; } template @@ -393,20 +428,14 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ const int iqs = i % (QK5_0/2); const int shift = (i % QK5_0) / (QK5_0/2); - const T d = x[ib].d; + const half d = x[ib].d; const int ql0 = x[ib].qs[iqs]; const int qh0 = get_int_b2(x[ib].qh, 0); const int ql = ((ql0 >> (4*shift)) & 0x0F); const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh) - 16; -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE - - return ((float) d)*((float) q); + return (T)(__half2float(d) * (float)q); } template @@ -425,13 +454,8 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh); -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE - - return __low2float(dm)*((float) q) + __high2float(dm); + const float result = __low2float(dm)*((float) q) + __high2float(dm); + return (T)result; } template @@ -441,23 +465,20 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ const int64_t ib = i / QK8_0; const int iqs = i % QK8_0; - const T d = x[ib].d; - const int q = x[ib].qs[iqs]; + const half d = x[ib].d; + const int q = x[ib].qs[iqs]; -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE - - return ((float) d)*((float) q); + return (T)(__half2float(d) * (float)q); } template static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { const half * x = (const half *) vx; - - return x[i]; + if constexpr (std::is_same_v) { + return __float2bfloat16(__half2float(x[i])); + } else { + return (T)x[i]; + } } template @@ 
-468,6 +489,29 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; +} + +template +static __device__ __forceinline__ T dequantize_1_bf16(const void * __restrict__ vx, const int64_t i) { + const nv_bfloat16 * x = (const nv_bfloat16 *) vx; + if constexpr (std::is_same_v) { + return __float2half(__bfloat162float(x[i])); + } else { + return (T)x[i]; + } +} + +template +constexpr __device__ vec_dot_KQ_bf16_t get_vec_dot_KQ_bf16(ggml_type type_K) { + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_bf16 : nullptr; } @@ -479,6 +523,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_bf16 : nullptr; } @@ -489,6 +534,18 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + type_V == GGML_TYPE_BF16 ? dequantize_1_bf16 : + nullptr; +} + +constexpr __device__ dequantize_1_bf16_t get_dequantize_1_bf16(ggml_type type_V) { + return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : + type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : + type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : + type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : + type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : + type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + type_V == GGML_TYPE_BF16 ? 
dequantize_1_bf16 :
+                                            nullptr;
+}
@@ -645,7 +702,7 @@ static __global__ void flash_attn_stream_k_fixup(
 template // D == head size
 #if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP)
+#endif // !(defined(GGML_USE_HIP))
 static __global__ void flash_attn_combine_results(
         const float * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
diff --git a/ggml/src/ggml-cuda/fattn-vec-bf16.cuh b/ggml/src/ggml-cuda/fattn-vec-bf16.cuh
new file mode 100644
index 0000000000000..a9cd9c0f71799
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-vec-bf16.cuh
@@ -0,0 +1,523 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+#ifdef __cplusplus
+// Template definition must come before macro usage, and must NOT be in extern "C"
+template
+void ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+template
+void ggml_cuda_flash_attn_ext_vec_bf16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+#endif
+
+// Device exponential for bfloat16 with clamping to avoid overflow/underflow
+#ifndef BF16_MAX
+#define BF16_MAX 3.38953139e+38f
+#endif
+#ifndef BF16_MIN
+#define BF16_MIN -3.38953139e+38f
+#endif
+
+__device__ __forceinline__ __nv_bfloat16 hexp_bf16(float x) {
+    float val = expf(x);
+    val = fminf(fmaxf(val, BF16_MIN), BF16_MAX);
+    return __float2bfloat16(val);
+}
+
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template // D == head size
+#ifndef GGML_USE_HIP
+__launch_bounds__(D, 1)
+#endif // GGML_USE_HIP
+static __global__ void flash_attn_vec_ext_bf16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int * __restrict__ KV_max,
+        float * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
+
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+
+    constexpr vec_dot_KQ_bf16_t vec_dot_KQ = get_vec_dot_KQ_bf16(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_BF16;
+    constexpr dequantize_1_bf16_t dequantize_1_v = get_dequantize_1_bf16(type_V);
+
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+ + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const __nv_bfloat16 * maskh = (const __nv_bfloat16 *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); + + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + const __nv_bfloat16 slopeh = __float2bfloat16(slopef); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ __nv_bfloat16 KQ[ncols*D]; + __nv_bfloat162 * KQ2 = (__nv_bfloat162 *) KQ; + + __nv_bfloat16 kqmax[ncols]; + __nv_bfloat16 kqsum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -BF16_MAX; + kqsum[j] = 0.0f; + } + + __shared__ __nv_bfloat16 kqmax_shared[ncols][WARP_SIZE]; + __shared__ __nv_bfloat16 kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -BF16_MAX; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + + __shared__ __nv_bfloat16 maskh_shared[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = 0.0f; + } + + __syncthreads(); + + // Convert Q to __nv_bfloat162 (bf16 K) or q8_1 (quantized K) and store in registers: + __nv_bfloat162 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + __nv_bfloat162 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + __nv_bfloat162 * tmp_q_ds = (__nv_bfloat162 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_bfloat162(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared<__nv_bfloat162>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + __nv_bfloat162 * tmp_q_ds = (__nv_bfloat162 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? 
Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_bfloat162(scale, scale) * make_bfloat162(tmp.x, tmp.y); + } + } + } + + + #pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -BF16_MAX; + } + __syncthreads(); + + __nv_bfloat162 VKQ[ncols] = {make_bfloat162(0.0f, 0.0f)}; + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + K += blockIdx.y*D * nb11; + V += blockIdx.y*D * nb21; + maskh += blockIdx.y*D; + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, + /* Increment pointers after each loop: */ + K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { + + // Calculate KQ tile and keep track of new maximum KQ values: + + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; + } + __syncthreads(); + } + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + __nv_bfloat16 kqmax_new = kqmax[0]; + __nv_bfloat16 kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + __nv_bfloat16 sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum((float)sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + + sum += maskh_shared[j*D + i_KQ]; + + if (ncols == 1) { + kqmax_new = ggml_cuda_bf16max(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_bf16max(kqmax_new_arr[j], sum); + } + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + __nv_bfloat16 kqmax_new_j = ncols == 1 ? 
kqmax_new : kqmax_new_arr[j]; + + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + __nv_bfloat16 kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const __nv_bfloat16 KQ_max_scale = hexp_bf16(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const __nv_bfloat16 val = hexp_bf16(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= __bfloat162bfloat162(KQ_max_scale); + } + + __syncthreads(); + +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + __nv_bfloat162 V_k; + reinterpret_cast<__nv_bfloat16&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast<__nv_bfloat16&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + } + } + + __syncthreads(); + } + + if (sinksf && blockIdx.y == 0) { + const __nv_bfloat16 sink = __float2bfloat16(sinksf[head]); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf((float)kqmax[j], (float)sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + __nv_bfloat16 kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const __nv_bfloat16 KQ_max_scale = hexp_bf16(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const __nv_bfloat16 val = hexp_bf16(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= __bfloat162bfloat162(KQ_max_scale); + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum((float)kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { + break; + } + + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); + + float dst_val = (__low2bfloat16(VKQ[j_VKQ]) + __high2bfloat16(VKQ[j_VKQ])); + if (gridDim.y == 1) { + const float inv_kqsum = 1.0f / (float)kqsum[j_VKQ]; + dst_val *= inv_kqsum; + } + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; + } + + if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2((float)kqmax[tid], (float)kqsum[tid]); + } +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); + NO_DEVICE_CODE; +#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) +} +#ifdef __clang__ 
+#pragma clang diagnostic pop +#endif // __clang__ + +template +void ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int nwarps = D/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_bf16; + GGML_ASSERT(fattn_kernel != nullptr); + constexpr bool need_f16_K = false; + constexpr bool need_f16_V = false; + constexpr size_t nbytes_shared = 0; + + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template +void ggml_cuda_flash_attn_ext_vec_bf16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K->type == type_K); + GGML_ASSERT(V->type == type_V); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } + return; + } + + constexpr int cols_per_block = 8; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_bf16_case_impl(ctx, dst); + } +} + +#define DECL_FATTN_VEC_BF16_CASE(D, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_bf16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) + +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, 
GGML_TYPE_Q4_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); + +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); + +DECL_FATTN_VEC_BF16_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 22e90d0e7b316..1fd96a827d701 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -4,6 +4,7 @@ #include "fattn-tile-f16.cuh" #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" +#include "fattn-vec-bf16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" #include "fattn.cuh" @@ -193,6 +194,80 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg on_no_fattn_vec_case(Q->ne[0]); } +#define FATTN_VEC_BF16_CASE(D, type_K, type_V) \ + if ((Q->ne[0]) == (D) && (K->type) == (type_K) && (V->type) == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_bf16_case(ctx, dst); \ + return; \ + } + +static inline void ggml_cuda_flash_attn_ext_vec_bf16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, 
GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16) +#else + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_BF16_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_BF16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); +} + #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ @@ -280,6 +355,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + // Corrected dispatch logic: check for bf16 first. 
+    if (K->type == GGML_TYPE_BF16 || V->type == GGML_TYPE_BF16) {
+        ggml_cuda_flash_attn_ext_vec_bf16(ctx, dst);
+        return;
+    }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index d9110491ec78c..5b13150998845 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3539,16 +3539,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 64) {
                 return true;
             }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256) {
                 return true;
             }
             if (op->src[3] && op->src[3]->ne[2] != 1) {
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index d058504cd6cc0..74c257ce9fdab 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -328,12 +328,16 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
     float * dst_d = (float *)dst->data;
+    const __nv_bfloat16 * src0_bf16 = (const __nv_bfloat16 *)src0->data;
+    __nv_bfloat16 * dst_bf16 = (__nv_bfloat16 *)dst->data;
     cudaStream_t stream = ctx.stream();
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16 ||  dst->type == GGML_TYPE_BF16);
     GGML_ASSERT(src0->type == dst->type);
     const int64_t ne00 = src0->ne[0]; // head dims
@@ -399,6 +403,10 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         rope_neox_cuda(
             (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
             freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    } else if (src0->type == GGML_TYPE_BF16) {
+        rope_neox_cuda(
+            (const __nv_bfloat16 *) src0_bf16, (__nv_bfloat16 *) dst_bf16, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -411,6 +419,10 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         rope_multi_cuda(
             (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
             freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+    } else if (src0->type == GGML_TYPE_BF16) {
+        rope_multi_cuda(
+            (const __nv_bfloat16 *) src0_bf16, (__nv_bfloat16 *) dst_bf16, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -423,6 +435,10 @@ void
ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) rope_vision_cuda( (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else if (src0->type == GGML_TYPE_BF16) { + rope_vision_cuda( + (const __nv_bfloat16 *) src0_bf16, (__nv_bfloat16 *) dst_bf16, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } @@ -435,6 +451,10 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) rope_norm_cuda( (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + } else if (src0->type == GGML_TYPE_BF16) { + rope_norm_cuda( + (const __nv_bfloat16 *) src0_bf16, (__nv_bfloat16 *) dst_bf16, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-bf16.cu new file mode 100644 index 0000000000000..7a8370ac91df7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-f16.cu new file mode 100644 index 0000000000000..71440fb617aac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_0.cu new file mode 100644 index 0000000000000..0f334ab8739de --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_1.cu new file mode 100644 index 0000000000000..7df004847069f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q4_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_0.cu new file mode 100644 index 0000000000000..b3a82fc92c71a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_1.cu new file mode 100644 index 0000000000000..e7ad4a4b3b6f0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q5_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q8_0.cu new file mode 100644 index 0000000000000..5170c2c470a5f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-bf16-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-bf16.cu new file mode 100644 index 0000000000000..410aeda7aa554 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-bf16.cu new file mode 100644 index 0000000000000..b9663f7622f35 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-bf16.cu new file mode 100644 index 0000000000000..c1180f0168813 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-bf16.cu new file mode 100644 index 0000000000000..46d873db49b60 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-bf16.cu new file mode 100644 index 0000000000000..8168ad3877398 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-bf16.cu new file mode 100644 index 0000000000000..91d4494b63fba --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-bf16-bf16.cu new file mode 100644 index 0000000000000..6baff7ae38e63 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-bf16-bf16.cu new file mode 100644 index 0000000000000..a76befba29557 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-bf16.cu new file mode 100644 index 0000000000000..446fba766873d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-bf16.cu new file mode 100644 index 0000000000000..173da40b8c7f3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-f16.cu new file mode 100644 index 0000000000000..f0f4394be27c5 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_0.cu new file mode 100644 index 0000000000000..8b8417ddb0be4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_1.cu new file mode 100644 index 0000000000000..328366460ada1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q4_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_0.cu new file mode 100644 index 0000000000000..3a4a1cc08db04 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_1.cu new file mode 100644 index 0000000000000..50bbcf0b1d9ee --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q5_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q8_0.cu new file mode 100644 index 0000000000000..0babcd8ee76b0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-bf16-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-bf16.cu new file mode 100644 index 0000000000000..3dd5ba2086eb7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-bf16.cu new file mode 100644 index 0000000000000..74ce58d72846b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-bf16.cu new file mode 100644 index 0000000000000..63731b66de917 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-bf16.cu new file mode 100644 index 0000000000000..6918d6c2eb830 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-bf16.cu new file mode 100644 index 0000000000000..71b6db00fb41f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-bf16.cu new file mode 100644 index 0000000000000..2654a17869cfe --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-bf16-bf16.cu new file mode 100644 index 0000000000000..e72357bc5a981 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-bf16-bf16.cu new file mode 100644 index 0000000000000..07ef310e7fe1d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-bf16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-bf16.cu new file mode 100644 index 0000000000000..e4755e176ee44 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-bf16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113dc8fd2..e34d733e7d868 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"] +TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16", "GGML_TYPE_BF16"] SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -42,6 +42,8 @@ def get_short_name(long_quant_name): def get_head_sizes(type_k, type_v): if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16": return [64, 128, 256] + if type_k == "GGML_TYPE_BF16" and type_v == "GGML_TYPE_BF16": + return [64, 128, 256] if type_k == "GGML_TYPE_F16": return [64, 128] return [128]