
WIP: ggml-cuda: Add bf16 cuda support to fattn (Flash Attention) #15261


Status: Open · wants to merge 3 commits into master
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -216,6 +216,10 @@ typedef float2 dfloat2;
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define BF16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
@@ -927,3 +931,7 @@ struct ggml_backend_cuda_context {
return pool(device);
}
};

static __device__ __forceinline__ __nv_bfloat16 ggml_cuda_bf16max(const __nv_bfloat16 a, const __nv_bfloat16 b) {
return __float2bfloat16(fmaxf((float)a, (float)b));
}
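A minimal usage sketch of the new helper (hypothetical device code, not part of this diff): a bf16 reduction that keeps a running maximum with ggml_cuda_bf16max. It assumes n >= 1 so the first element can seed the maximum.

static __device__ __forceinline__ __nv_bfloat16 bf16_array_max_sketch(const __nv_bfloat16 * vals, const int n) {
    // Seed with the first element, then fold in the rest via the helper above.
    __nv_bfloat16 m = vals[0];
    for (int i = 1; i < n; ++i) {
        m = ggml_cuda_bf16max(m, vals[i]);
    }
    return m;
}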
46 changes: 44 additions & 2 deletions ggml/src/ggml-cuda/convert.cu
@@ -649,10 +649,52 @@ static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k,

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
Comment on lines +652 to +697
Collaborator:

Please keep the order of the types the same as their definition in the enum.

Author:

Are we talking about the ggml_type enum? None of the three conversion functions list the types in the same order. Should I update them all to be consistent?

My change was just copying the f32 function's cases

$ diff bf16 f32
46,47d45
<         case GGML_TYPE_F32:
<             return convert_unary_cont_cuda<float>;

and adding the GGML_TYPE_F32 case.

Happy to fix up the cases as appropriate; just let me know how you want to proceed.

Collaborator:

Regardless of what order the types were in before, when you write a switch like this, please use the order in which the types are declared in the ggml_type enum in ggml.h.

default:
return nullptr;
}
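To illustrate the ordering requested above: a hedged sketch of how the start of this switch could look with the cases in ggml_type declaration order (order assumed from ggml.h; only the first few cases are shown, and the handlers themselves are unchanged).

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
        // cases listed in the declaration order of ggml_type in ggml.h
        case GGML_TYPE_F32:
            return convert_unary_cont_cuda<float>;
        case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        // ... remaining quantized, IQ, MXFP4 and BF16 cases, kept in the same declaration order ...
        default:
            return nullptr;
    }
}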
179 changes: 118 additions & 61 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -5,6 +5,7 @@
#include "vecdotq.cuh"

#include <cstdint>
#include <cuda_bf16.h>

#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
@@ -35,6 +36,8 @@ typedef void (* fattn_kernel_t)(

typedef half (*vec_dot_KQ_f16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef nv_bfloat16 (*vec_dot_KQ_bf16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef float (*vec_dot_KQ_f32_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

@@ -238,16 +241,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(

const int v = get_int_b2(K_q8_0[ib].qs, iqs);

T Q_d;
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
} else {
if constexpr (std::is_same_v<T, nv_bfloat16>) {
const float2 * Q_ds = (const float2 *) Q_ds_v;
Q_d = Q_ds[k_KQ_0/warp_size].x;
nv_bfloat16 Q_d = Q_ds[k_KQ_0/warp_size].x;
nv_bfloat16 K_d = __float2bfloat16(__half2float(K_q8_0[ib].d));
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_d, Q_d);
} else {
T Q_d;
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
} else {
const float2 * Q_ds = (const float2 *) Q_ds_v;
Q_d = Q_ds[k_KQ_0/warp_size].x;
}
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
}

sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
}

return sum;
@@ -262,7 +271,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
GGML_UNUSED(Q_ds_v);

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
if constexpr (std::is_same_v<T, half>) {
const half2 * Q_h2 = (const half2 *) Q_v;

half2 sum2 = make_half2(0.0f, 0.0f);
@@ -272,27 +281,63 @@
const int k_KQ = k_KQ_0 + threadIdx.x;

const half2 K_ik = K_h2[k_KQ];
sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
sum2 = __hadd2(sum2, __hmul2(K_ik, Q_h2[k_KQ_0/warp_size]));
}

return __low2half(sum2) + __high2half(sum2);
}
} else
#endif // FP16_AVAILABLE
{
const nv_bfloat162 * Q_bf16_2 = (const nv_bfloat162 *) Q_v;
float sum = 0.0f;

const float2 * Q_f2 = (const float2 *) Q_v;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;

float sum = 0.0f;
const half2 K_ik = K_h2[k_KQ];
const nv_bfloat162 Q_ik = Q_bf16_2[k_KQ_0/warp_size];
const float2 Q_f2 = __bfloat1622float2(Q_ik);

sum += __half2float(__low2half(K_ik)) * Q_f2.x;
sum += __half2float(__high2half(K_ik)) * Q_f2.y;
}

return (T)sum;
}
}

template <typename T, int D, int warp_size>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_bf16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const nv_bfloat162 * K_bf16_2 = (const nv_bfloat162 *) K_c;
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);

#ifdef BF16_AVAILABLE
if (std::is_same<T, nv_bfloat16>::value) {
const nv_bfloat162 * Q_bf16_2 = (const nv_bfloat162 *) Q_v;
nv_bfloat162 sum2 = make_bfloat162(0.0f, 0.0f);
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
sum2 = __hadd2(sum2, __hmul2(K_bf16_2[k_KQ], Q_bf16_2[k_KQ_0/warp_size]));
}
return __low2bfloat16(sum2) + __high2bfloat16(sum2);
}
#endif // BF16_AVAILABLE

const float2 * Q_f2 = (const float2 *) Q_v;
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;

const half2 K_ik = K_h2[k_KQ];
sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
const nv_bfloat162 K_ik = K_bf16_2[k_KQ];
const float2 K_f2 = __bfloat1622float2(K_ik);
sum += K_f2.x * Q_f2[k_KQ_0/warp_size].x;
sum += K_f2.y * Q_f2[k_KQ_0/warp_size].y;
}

return sum;
return (T)sum;
}

template <typename Tds>
@@ -340,6 +385,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
}

typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
typedef nv_bfloat16 (*dequantize_1_bf16_t)(const void *, const int64_t);
typedef float (*dequantize_1_f32_t)(const void *, const int64_t);

template <typename T>
@@ -350,17 +396,11 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
const int iqs = i % (QK4_0/2);
const int shift = (i % QK4_0) / (QK4_0/2);

const T d = x[ib].d;
const half d = x[ib].d;
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
#endif // FP16_AVAILABLE

return ((float) d)*((float) q);
return (T)(__half2float(d) * (float)q);
}

template <typename T>
@@ -375,13 +415,8 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F);

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
#endif // FP16_AVAILABLE

return __low2float(dm)*((float) q) + __high2float(dm);
const float result = __low2float(dm)*((float) q) + __high2float(dm);
return (T)result;
}

template <typename T>
@@ -393,20 +428,14 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
const int iqs = i % (QK5_0/2);
const int shift = (i % QK5_0) / (QK5_0/2);

const T d = x[ib].d;
const half d = x[ib].d;
const int ql0 = x[ib].qs[iqs];
const int qh0 = get_int_b2(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
#endif // FP16_AVAILABLE

return ((float) d)*((float) q);
return (T)(__half2float(d) * (float)q);
}

template <typename T>
@@ -425,13 +454,8 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
#endif // FP16_AVAILABLE

return __low2float(dm)*((float) q) + __high2float(dm);
const float result = __low2float(dm)*((float) q) + __high2float(dm);
return (T)result;
}

template <typename T>
@@ -441,23 +465,20 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
const int64_t ib = i / QK8_0;
const int iqs = i % QK8_0;

const T d = x[ib].d;
const int q = x[ib].qs[iqs];
const half d = x[ib].d;
const int q = x[ib].qs[iqs];

#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
#endif // FP16_AVAILABLE

return ((float) d)*((float) q);
return (T)(__half2float(d) * (float)q);
}

template <typename T>
static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
const half * x = (const half *) vx;

return x[i];
if constexpr (std::is_same_v<T, nv_bfloat16>) {
return __float2bfloat16(__half2float(x[i]));
} else {
return (T)x[i];
}
}

template <int D, int warp_size = WARP_SIZE>
@@ -468,6 +489,29 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
nullptr;
}

template <typename T>
static __device__ __forceinline__ T dequantize_1_bf16(const void * __restrict__ vx, const int64_t i) {
const nv_bfloat16 * x = (const nv_bfloat16 *) vx;
if constexpr (std::is_same_v<T, half>) {
return __float2half(__bfloat162float(x[i]));
} else {
return (T)x[i];
}
}

template <int D, int warp_size = WARP_SIZE>
constexpr __device__ vec_dot_KQ_bf16_t get_vec_dot_KQ_bf16(ggml_type type_K) {
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<nv_bfloat16, D, warp_size> :
type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_bf16<nv_bfloat16, D, warp_size> :
nullptr;
}

@@ -479,6 +523,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
type_K == GGML_TYPE_BF16 ? vec_dot_fattn_vec_KQ_bf16<float, D, warp_size> :
nullptr;
}

@@ -489,6 +534,18 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
type_V == GGML_TYPE_BF16 ? dequantize_1_bf16<half> :
nullptr;
}

constexpr __device__ dequantize_1_bf16_t get_dequantize_1_bf16(ggml_type type_V) {
return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<nv_bfloat16> :
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<nv_bfloat16> :
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<nv_bfloat16> :
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<nv_bfloat16> :
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<nv_bfloat16> :
type_V == GGML_TYPE_F16 ? dequantize_1_f16<nv_bfloat16> :
type_V == GGML_TYPE_BF16 ? dequantize_1_bf16<nv_bfloat16> :
nullptr;
}
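For context, a hedged sketch of how a bf16 flash-attention vector kernel might consume these new dispatchers, mirroring the existing f16 path (the kernel name and parameter list below are hypothetical and not part of this diff):

template<int D, ggml_type type_K, ggml_type type_V, int warp_size = WARP_SIZE>
static __global__ void flash_attn_vec_ext_bf16_sketch(
        const char * __restrict__ Q, const char * __restrict__ K, const char * __restrict__ V) {
    // Compile-time selection of the KQ dot product and the V dequantizer,
    // analogous to get_vec_dot_KQ_f16 / get_dequantize_1_f16 in the f16 kernel.
    constexpr vec_dot_KQ_bf16_t   vec_dot_KQ     = get_vec_dot_KQ_bf16<D, warp_size>(type_K);
    constexpr dequantize_1_bf16_t dequantize_1_v = get_dequantize_1_bf16(type_V);

    // The KQ accumulation and V gathering would then call vec_dot_KQ and
    // dequantize_1_v exactly as the f16 variant does; omitted in this sketch.
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V);
    GGML_UNUSED(vec_dot_KQ); GGML_UNUSED(dequantize_1_v);
}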

@@ -645,7 +702,7 @@ static __global__ void flash_attn_stream_k_fixup(
template<int D> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP)
#endif // !(defined(GGML_USE_HIP))
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,