Commit 865d0b1

WIP: ggml-cuda: Add bf16 cuda support to fattn (Flash Attention)
1 parent 25ff6f7 commit 865d0b1

44 files changed (+967 / -76 lines)


examples/cuda_p2p_bench.cpp

Whitespace-only changes.

examples/cuda_p2p_bench.h

Whitespace-only changes.

examples/test_nccl_sendrecv_bandwidth.cpp

Whitespace-only changes.

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 5 additions & 4 deletions
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
         # 80 == Ampere, asynchronous data loading, faster tensor core instructions
         # 86 == RTX 3000, needs CUDA v11.1
         # 89 == RTX 4000, needs CUDA v11.8
+        # 120 == Blackwell, needs CUDA v12.8
         #
         # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
         # XX-real == compile CUDA code as device code for this specific architecture
@@ -26,15 +27,15 @@ if (CUDAToolkit_FOUND)
             set(CMAKE_CUDA_ARCHITECTURES "native")
         elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real;120-real")
             else()
-                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+                set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;120-real")
             endif()
         else()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
-                set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
+                set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real;120-real")
             else()
-                set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
+                set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;120-real")
             endif()
         endif()
     endif()
 endif()
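
Note: adding "120-real" means the fatbin now carries device code built specifically for Blackwell (compute capability 12.0). A stand-alone way to confirm which target a kernel was compiled for (not part of this commit; the file name is illustrative) is to check __CUDA_ARCH__ in device code, which reports 1200 for the 120-real entry:

// check_arch.cu -- illustrative only; compile with e.g. `nvcc -arch=sm_120 check_arch.cu`
// (sm_120 requires CUDA 12.8 or newer).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void report_arch() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200
    // This branch is compiled into the 120-real (Blackwell) device code.
    printf("device code built for sm_120 (Blackwell) or newer\n");
#else
    printf("device code built for an architecture below sm_120\n");
#endif
}

int main() {
    report_arch<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}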

ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 0 deletions
@@ -216,6 +216,10 @@ typedef float2 dfloat2;
 #define FP16_AVAILABLE
 #endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

+#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define BF16_AVAILABLE
+#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
@@ -927,3 +931,7 @@ struct ggml_backend_cuda_context {
         return pool(device);
     }
 };
+
+static __device__ __forceinline__ __nv_bfloat16 ggml_cuda_bf16max(const __nv_bfloat16 a, const __nv_bfloat16 b) {
+    return __float2bfloat16(fmaxf((float)a, (float)b));
+}
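
Note: the BF16_AVAILABLE guard and ggml_cuda_bf16max give the fattn kernels a bf16 counterpart to the existing fp16 max. Below is a minimal, stand-alone sketch (not from this commit; the kernel and its names are assumptions, only the helper is copied from the diff above) of how such a helper could track the running row maximum that flash attention's softmax needs. The block-wide reduction is deliberately left out to keep it short.

// bf16_row_max.cu -- illustrative sketch only.
#include <cuda_bf16.h>
#include <math.h>

// Copied from the helper added to common.cuh above: max of two bf16 values via float.
static __device__ __forceinline__ __nv_bfloat16 ggml_cuda_bf16max(const __nv_bfloat16 a, const __nv_bfloat16 b) {
    return __float2bfloat16(fmaxf((float)a, (float)b));
}

// Each thread keeps a running maximum over a strided slice of one row of logits.
__global__ void row_max_bf16(const __nv_bfloat16 * x, __nv_bfloat16 * dst, const int n) {
    __nv_bfloat16 m = __float2bfloat16(-INFINITY);
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        m = ggml_cuda_bf16max(m, x[i]);
    }
    // A real kernel would now combine the per-thread maxima with a warp-shuffle
    // or shared-memory reduction; omitted here for brevity.
    if (threadIdx.x == 0) {
        dst[0] = m;
    }
}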

ggml/src/ggml-cuda/convert.cu

Lines changed: 44 additions & 2 deletions
@@ -649,10 +649,52 @@ static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k,

 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
     switch (type) {
-        case GGML_TYPE_F32:
-            return convert_unary_cont_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
+        case GGML_TYPE_F32:
+            return convert_unary_cont_cuda<float>;
         default:
             return nullptr;
     }
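
Note: ggml_get_to_bf16_cuda now covers the quantized formats as well as f32/f16/bf16, so callers can dequantize or convert K/V data into a bf16 scratch buffer before launching the bf16 flash-attention path. A hedged usage sketch follows (not from this commit; the header paths and the to_bf16_cuda_t signature are assumptions, mirroring the existing f16 conversion convention of source pointer, destination pointer, element count, stream):

// to_bf16_usage.cu -- illustrative sketch only.
#include "ggml.h"          // ggml_type
#include "convert.cuh"     // to_bf16_cuda_t, ggml_get_to_bf16_cuda (this commit)
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>

// Convert n elements of src (any type handled by the switch above) into a bf16
// buffer on the given stream; returns false for unsupported types so the caller
// can fall back to the f16/f32 path.
static bool convert_to_bf16(ggml_type type, const void * src, nv_bfloat16 * dst,
                            int64_t n, cudaStream_t stream) {
    const to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(type);
    if (to_bf16 == nullptr) {
        return false;
    }
    to_bf16(src, dst, n, stream);   // dequantize/convert asynchronously on `stream`
    return true;
}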
