WIP: ggml-cuda: Add bf16 cuda support to fattn (Flash Attention) #15261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into master
Changes from 1 commit
Empty file added examples/cuda_p2p_bench.cpp
Empty file added examples/cuda_p2p_bench.h
9 changes: 5 additions & 4 deletions ggml/src/ggml-cuda/CMakeLists.txt
Collaborator:

This will break CUDA 11 compatibility.

Author:

This was some hackery from when I was having trouble getting the F32 kernel device code to compile and wasn't confident CMake was respecting my -DCMAKE_CUDA_ARCHITECTURES flag. I reverted it instead of adjusting it, since it's not relevant to the change.

@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
    # 80 == Ampere, asynchronous data loading, faster tensor core instructions
    # 86 == RTX 3000, needs CUDA v11.1
    # 89 == RTX 4000, needs CUDA v11.8
+   # 120 == Blackwell, needs CUDA v12.8
    #
    # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
    # XX-real == compile CUDA code as device code for this specific architecture
@@ -26,15 +27,15 @@ if (CUDAToolkit_FOUND)
        set(CMAKE_CUDA_ARCHITECTURES "native")
    elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real;120-real")
        else()
-            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+            set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;120-real")
        endif()
    else()
        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real;120-real")
        else()
-            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+            set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;120-real")
        endif()
    endif()
endif()
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -216,6 +216,10 @@ typedef float2 dfloat2;
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

+#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define BF16_AVAILABLE
+#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
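
For context: a guard like BF16_AVAILABLE is typically consumed in device code to select between native bf16 arithmetic (Ampere and newer, or HIP) and an f32 fallback. A minimal sketch of that pattern, with a hypothetical helper name that is not taken from this PR:

#include <cuda_bf16.h>

// bf16_mul_to_f32 is a hypothetical example helper, not part of the PR.
static __device__ __forceinline__ float bf16_mul_to_f32(const __nv_bfloat16 a, const __nv_bfloat16 b) {
#ifdef BF16_AVAILABLE
    // Ampere and newer (or HIP): use the native bf16 multiply intrinsic.
    return __bfloat162float(__hmul(a, b));
#else
    // Older architectures: widen both operands to f32 and multiply there.
    return __bfloat162float(a) * __bfloat162float(b);
#endif
}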
@@ -927,3 +931,7 @@ struct ggml_backend_cuda_context {
        return pool(device);
    }
};

+static __device__ __forceinline__ __nv_bfloat16 ggml_cuda_bf16max(const __nv_bfloat16 a, const __nv_bfloat16 b) {
+    return __float2bfloat16(fmaxf((float)a, (float)b));
+}
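
For a sense of how this helper would be used in the flash attention kernels: the online softmax tracks a running row maximum, which is then reduced across the warp. The kernel below is an illustrative sketch only; the kernel name and launch shape are not from this PR:

#include <cuda_bf16.h>
#include <math.h>

// Illustrative single-warp running max over a bf16 row, i.e. the reduction
// pattern flash attention's online softmax relies on. Launch as <<<1, 32>>>.
__global__ void row_max_bf16(const __nv_bfloat16 * x, __nv_bfloat16 * out, const int n) {
    __nv_bfloat16 m = __float2bfloat16(-INFINITY);
    for (int i = threadIdx.x; i < n; i += warpSize) {
        m = ggml_cuda_bf16max(m, x[i]);
    }
    // butterfly reduction: after log2(32) steps every lane holds the row max
    for (int offset = warpSize/2; offset > 0; offset >>= 1) {
        const float other = __shfl_xor_sync(0xFFFFFFFF, __bfloat162float(m), offset);
        m = ggml_cuda_bf16max(m, __float2bfloat16(other));
    }
    if (threadIdx.x == 0) {
        *out = m;
    }
}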
46 changes: 44 additions & 2 deletions ggml/src/ggml-cuda/convert.cu
@@ -649,10 +649,52 @@ static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k,

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
-        case GGML_TYPE_F32:
-            return convert_unary_cont_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
+        case GGML_TYPE_F32:
+            return convert_unary_cont_cuda<float>;
Comment on lines +652 to +697
Collaborator:

Please keep the order of types the same as their definition in the enum.

Author:

Are we talking about the ggml_type enum? None of the three conversion functions list their types in the same order. Should I update them all to be consistent?

My change was just copying the F32 version:

$ diff bf16 f32
46,47d45
<         case GGML_TYPE_F32:
<             return convert_unary_cont_cuda<float>;

and adding the GGML_TYPE_F32 case.

Happy to fix up the cases as appropriate; just let me know how you want to proceed.

Collaborator:

Regardless of what order the types were in before, when you write a switch like this, please use the order in which they are declared in the ggml_type enum in ggml.h.

        default:
            return nullptr;
    }
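
For illustration, a sketch of the ordering the review asks for: cases follow the declaration order of enum ggml_type in ggml.h (GGML_TYPE_F32 = 0, GGML_TYPE_F16 = 1, GGML_TYPE_Q4_0 = 2, ...), abbreviated here:

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cont_cuda<float>;
        case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        // ... the remaining quantized types declared before BF16, in enum order ...
        case GGML_TYPE_BF16:
            return convert_unary_cont_cuda<nv_bfloat16>;
        // ... types declared after BF16 in the enum, e.g. GGML_TYPE_MXFP4 ...
        default:
            return nullptr;
    }
}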