From 206d22bd40e76cb7bab1ba047b78136f97db8a44 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 01:56:19 +0800 Subject: [PATCH 01/13] Add GGML_HIP_ROCWMMA_FATTN and rocwmma header check --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-hip/CMakeLists.txt | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index fc5eac151b90c..a79b93a6b63db 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -157,6 +157,7 @@ option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp on option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) +option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f4a4683639fab..dfa39d1758d6e 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -39,6 +39,12 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_ROCWMMA_FATTN) + CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) + if (NOT ${FOUND_ROCWMMA}) + message(FATAL_ERROR "rocwmma is not found") + endif() +endif() if (${hip_VERSION} VERSION_LESS 5.5) message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") @@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM) add_compile_definitions(GGML_HIP_NO_VMM) endif() +if (GGML_HIP_ROCWMMA_FATTN) + add_compile_definitions(GGML_HIP_ROCWMMA_FATTN) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) From 02369da4960271c49ab7f7c2b44468a414cbe911 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 03:47:43 +0800 Subject: [PATCH 02/13] Add rocWMMA support --- ggml/src/ggml-cuda/common.cuh | 10 ++++- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 65 +++++++++++++++++++++++++++- ggml/src/ggml-cuda/fattn.cu | 7 +++ 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7e99838c09261..fe32473653131 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -196,6 +196,10 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#define FP16_MMA_AVAILABLE +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -223,12 +227,14 @@ static bool fast_fp16_hardware_available(const int cc) { // Any FP16 tensor core instructions are available for ggml code. 
static bool fp16_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || + cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; + return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || + cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index de38470abec45..26eb9c7885142 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -7,7 +7,12 @@ #include "fattn-wmma-f16.cuh" #ifdef FP16_MMA_AVAILABLE +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #include +#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers +#include +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: @@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -68,11 +73,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; typedef nvcuda::wmma::fragment frag_c_KQ; typedef nvcuda::wmma::fragment frag_c_VKQ; +#else + typedef rocwmma::fragment frag_a_K; + typedef rocwmma::fragment frag_a_V; + typedef rocwmma::fragment frag_b; + typedef rocwmma::fragment frag_c_KQ; + typedef rocwmma::fragment frag_c_VKQ; +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. 
@@ -162,7 +175,11 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); +#else + rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } @@ -176,20 +193,36 @@ static __global__ void flash_attn_ext_f16( frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); +#else + rocwmma::fill_fragment(KQ_c[j], static_cast(0.0f)); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#else + rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); +#else + rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); +#else + rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } @@ -308,10 +341,17 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { const int k = k0 + (threadIdx.y % VKQ_ratio)*16; +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], KQ + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); +#else + rocwmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } @@ -320,7 +360,11 @@ static __global__ void flash_attn_ext_f16( for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); +#else + rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast(0.0f)); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } #pragma unroll @@ -328,10 +372,18 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#else + rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 
+ k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); +#else + rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } } @@ -343,10 +395,17 @@ static __global__ void flash_attn_ext_f16( for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) nvcuda::wmma::store_matrix_sync( KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); +#else + rocwmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, rocwmma::mem_col_major); +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } @@ -425,7 +484,7 @@ static __global__ void flash_attn_ext_f16( } #else NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)) } constexpr int get_max_power_of_2(int x) { @@ -574,6 +633,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; switch (Q->ne[0]) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) case 64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); break; @@ -586,6 +646,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) default: GGML_ABORT("fatal error"); break; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b1becccb4de72..c5726691a6774 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -254,6 +254,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= GGML_CUDA_CC_OFFSET_AMD) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) + if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; + } +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { From 419f1ea9cb8da4ca4135ca0807ed23d92b8ff4f7 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 18:18:12 +0800 Subject: [PATCH 03/13] Update ggml/src/ggml-hip/CMakeLists.txt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-hip/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e9c835ce2ab21..e3762649fd275 100644 
--- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -42,7 +42,7 @@ find_package(rocblas REQUIRED) if (GGML_HIP_ROCWMMA_FATTN) CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) if (NOT ${FOUND_ROCWMMA}) - message(FATAL_ERROR "rocwmma is not found") + message(FATAL_ERROR "rocwmma has not been found") endif() endif() From 828577a9d632172770558308cc48292d8cade539 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 18:17:52 +0800 Subject: [PATCH 04/13] Move comments to reduce confusion. --- ggml/src/ggml-cuda/fattn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c5726691a6774..c4cabe4a7ef07 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -252,7 +252,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= GGML_CUDA_CC_OFFSET_AMD) { #if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) { @@ -261,6 +260,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } #endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) + // On AMD the tile kernels perform poorly, use the vec kernel instead: if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { From 9d27c38b02b060433a0818571c63a598cb3da7d9 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 18:37:24 +0800 Subject: [PATCH 05/13] Use namespace alias `wmma` instead of lots of ifdefs. --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 88 ++++++---------------------- 1 file changed, 18 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 42825413ee9f4..6b63150850699 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -9,9 +9,11 @@ #ifdef FP16_MMA_AVAILABLE #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #include +namespace wmma = nvcuda::wmma; #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers #include +namespace wmma = rocwmma; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // FP16_MMA_AVAILABLE @@ -73,19 +75,11 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 
8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; -#else - typedef rocwmma::fragment frag_a_K; - typedef rocwmma::fragment frag_a_V; - typedef rocwmma::fragment frag_b; - typedef rocwmma::fragment frag_c_KQ; - typedef rocwmma::fragment frag_c_VKQ; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -175,11 +169,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); -#else - rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } } @@ -193,36 +183,20 @@ static __global__ void flash_attn_ext_f16( frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); -#else - rocwmma::fill_fragment(KQ_c[j], static_cast(0.0f)); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::fill_fragment(KQ_c[j], static_cast(0.0f)); } #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#else - rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); -#else - rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); -#else - rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, 
KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major); } } @@ -341,17 +315,10 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { const int k = k0 + (threadIdx.y % VKQ_ratio)*16; -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::load_matrix_sync( + wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], KQ + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); -#else - rocwmma::load_matrix_sync( - KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, - kqar*kqs_padded); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) } } @@ -360,11 +327,7 @@ static __global__ void flash_attn_ext_f16( for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); -#else - rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast(0.0f)); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast(0.0f)); } #pragma unroll @@ -372,18 +335,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#else - rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); -#else - rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } @@ -395,17 +350,10 @@ static __global__ void flash_attn_ext_f16( for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - nvcuda::wmma::store_matrix_sync( + wmma::store_matrix_sync( KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); -#else - rocwmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), - VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, rocwmma::mem_col_major); -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + D_padded, wmma::mem_col_major); } } From 19272bfaa4db1adb29e5b04b4d22fe70443a4e15 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 23 Feb 2025 19:20:43 +0800 Subject: [PATCH 06/13] Fix: FP16_MMA_AVAILABLE should not be checked in host code. 
--- ggml/src/ggml-cuda/fattn.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c4cabe4a7ef07..1437019988007 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -253,12 +253,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); if (cc >= GGML_CUDA_CC_OFFSET_AMD) { -#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#if defined(GGML_HIP_ROCWMMA_FATTN) if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#endif // defined(GGML_HIP_ROCWMMA_FATTN) // On AMD the tile kernels perform poorly, use the vec kernel instead: if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { From 29debe14cee4e1c91539a7133453287a8da179fa Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:35:22 +0800 Subject: [PATCH 07/13] Always return false in `fp16_mma_available` when compiling for HIP and GGML_HIP_ROCWMMA_FATTN is disabled. --- ggml/src/ggml-cuda/common.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index edceacaa72790..f90e7c6daa660 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -227,8 +227,12 @@ static bool fast_fp16_hardware_available(const int cc) { // Any FP16 tensor core instructions are available for ggml code. static bool fp16_mma_available(const int cc) { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. 
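Taken together, PATCH 02 and PATCH 07 give fp16_mma_available() two layers of gating: a compile-time short-circuit (HIP builds without GGML_HIP_ROCWMMA_FATTN always report false) and a runtime compute-capability test that accepts Volta and newer NVIDIA GPUs plus CDNA and RDNA3+ AMD GPUs, whose cc values sit above GGML_CUDA_CC_OFFSET_AMD. The standalone sketch below restates that rule outside the ggml tree; the numeric constants are illustrative stand-ins for the GGML_CUDA_CC_* macros (only the RDNA values appear verbatim in this series), the hip_build/rocwmma_fattn parameters model what are really preprocessor defines, and the NVIDIA branch omits the ggml_cuda_highest_compiled_arch() refinement.

    // Illustrative restatement of the fp16 MMA availability rule after PATCH 02 + PATCH 07.
    // Not ggml source; constant values are assumptions except where noted.
    #include <cstdio>

    namespace sketch {
    constexpr int CC_OFFSET_AMD = 0x1000000;              // assumed offset separating AMD from NVIDIA ccs
    constexpr int CC_VOLTA      = 700;
    constexpr int CC_CDNA       = CC_OFFSET_AMD + 0x908;  // assumed first CDNA architecture
    constexpr int CC_RDNA1      = CC_OFFSET_AMD + 0x1010; // assumed first RDNA architecture
    constexpr int CC_RDNA3      = CC_OFFSET_AMD + 0x1100; // matches the RDNA3 define visible later in the series

    // hip_build / rocwmma_fattn stand in for GGML_USE_HIP + __HIP_PLATFORM_AMD__ and
    // GGML_HIP_ROCWMMA_FATTN, which are compile-time macros in the real code.
    bool fp16_mma_available(int cc, bool hip_build, bool rocwmma_fattn) {
        if (hip_build && !rocwmma_fattn) {
            return false;                                  // PATCH 07: no rocWMMA, no WMMA FlashAttention
        }
        const bool nvidia_volta_plus = cc <  CC_OFFSET_AMD && cc >= CC_VOLTA;
        const bool amd_cdna          = cc >= CC_CDNA       && cc <  CC_RDNA1;
        const bool amd_rdna3_plus    = cc >= CC_RDNA3;
        return nvidia_volta_plus || amd_cdna || amd_rdna3_plus;
    }
    } // namespace sketch

    int main() {
        std::printf("RDNA3, rocWMMA on:  %d\n", sketch::fp16_mma_available(sketch::CC_RDNA3,        true, true));  // 1
        std::printf("RDNA3, rocWMMA off: %d\n", sketch::fp16_mma_available(sketch::CC_RDNA3,        true, false)); // 0
        std::printf("RDNA2, rocWMMA on:  %d\n", sketch::fp16_mma_available(sketch::CC_RDNA1 + 0x20, true, true));  // 0
        return 0;
    }

PATCH 13 later tidies the same expression into GGML_CUDA_CC_IS_CDNA(cc) without changing the accepted set of devices.
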
From 5d4ab04cf45a92c3e86325460958243847dd9381 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:41:56 +0800 Subject: [PATCH 08/13] Remove the Q->ne[1] > 8 check --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 ++-- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6b63150850699..68cfc6a3d1883 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -578,10 +578,10 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten return; } +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; switch (Q->ne[0]) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) case 64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); break; @@ -594,13 +594,13 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) default: GGML_ABORT("fatal error"); break; } return; } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) if (Q->ne[1] <= 32) { constexpr int cols_per_block = 16; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1437019988007..decceeb246af9 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -254,7 +254,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (cc >= GGML_CUDA_CC_OFFSET_AMD) { #if defined(GGML_HIP_ROCWMMA_FATTN) - if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) { + if (fp16_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } From 55169095e2a7e96eb35605ea083800bdba25666a Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:12:04 +0800 Subject: [PATCH 09/13] Also always return false in fp16_mma_hardware_available when compiled for AMD and GGML_HIP_ROCWMMA_FATTN not enabled. --- ggml/src/ggml-cuda/common.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f90e7c6daa660..8eab01b60bf1c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -237,8 +237,12 @@ static bool fp16_mma_available(const int cc) { // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. From fea171f548a0d401cb69f78eb2ec7880199b76b6 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:53:31 +0800 Subject: [PATCH 10/13] Revert "Also always return false in fp16_mma_hardware_available when compiled for AMD and GGML_HIP_ROCWMMA_FATTN not enabled." 
This reverts commit 55169095e2a7e96eb35605ea083800bdba25666a. --- ggml/src/ggml-cuda/common.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8eab01b60bf1c..f90e7c6daa660 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -237,12 +237,8 @@ static bool fp16_mma_available(const int cc) { // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { -#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) - return false; -#else return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; -#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. From a90f4cb757bef9696dbb0a8a106d50a908ccf2a8 Mon Sep 17 00:00:00 2001 From: Ben Jackson Date: Sun, 2 Mar 2025 18:51:28 -0800 Subject: [PATCH 11/13] ggml: Make fattn use hardware warp size instead of 32 --- ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/fattn-common.cuh | 66 ++++++++++++--------- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 89 ++++++++++++++-------------- ggml/src/ggml-cuda/fattn.cu | 3 +- 4 files changed, 87 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f90e7c6daa660..0a801d7ae9dbc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -62,6 +62,7 @@ #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 7b9566fb4be32..46de14093545c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q4_0[ib].d) * 
Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); } else @@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid4d8 + m4s8scaled); } @@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* 
*16/QI8_1 == 2 */; } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); } else @@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid5d8 + m5s8scaled); } @@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T Q_d; if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); } else { const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + Q_d = Q_ds[k_KQ_0/warp_size].x; } - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); } return sum; @@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = 
(const half2 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); @@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; } return __low2half(sum2) + __high2half(sum2); @@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; } return sum; @@ -698,6 +704,8 @@ void launch_fattn( GGML_ASSERT(Q->ne[3] == 1); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -750,7 +758,7 @@ void launch_fattn( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; - const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. @@ -796,6 +804,8 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + GGML_ASSERT(block_dim.x % warp_size == 0); + GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size); fattn_kernel<<>>( (const char *) Q->data, K_data, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 68cfc6a3d1883..21e59df6cccc8 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -20,7 +20,7 @@ namespace wmma = rocwmma; // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -67,6 +67,8 @@ static __global__ void flash_attn_ext_f16( //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. 
@@ -139,9 +141,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); @@ -153,9 +155,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; @@ -209,27 +211,27 @@ static __global__ void flash_attn_ext_f16( const int j = j0 + threadIdx.y; if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; + float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k]; if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]); } } float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] += mask ? 
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } - KQ_max_new = warp_reduce_max(KQ_max_new); + KQ_max_new = warp_reduce_max(KQ_max_new); const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; KQ_max_scale_f[j0/nwarps] = expf(diff); @@ -240,48 +242,48 @@ static __global__ void flash_attn_ext_f16( float KQ_rowsum_add = 0.0f; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/warp_size] = expf(diff); if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + KQ_f_tmp[k0/warp_size] = 0.0f; } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + KQ_rowsum_add += KQ_f_tmp[k0/warp_size]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k]; if (use_logit_softcap) { // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f)); - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + KQ2_tmp[k0/warp_size] *= logit_softcap_2; } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + KQ2_tmp[k0/warp_size] += mask ? 
slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); @@ -290,17 +292,17 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/warp_size] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/warp_size]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; @@ -371,9 +373,9 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } @@ -405,9 +407,9 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } float dst_val = VKQ[j_VKQ*D_padded + i]; @@ -522,6 +524,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * Q = dst->src[0]; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; if (prec != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -579,7 +582,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten } #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) { constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index decceeb246af9..24f973056aa9a 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -250,6 +250,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = 
ggml_flash_attn_ext_get_prec(KQV); if (cc >= GGML_CUDA_CC_OFFSET_AMD) { @@ -298,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int gqa_ratio = Q->ne[2] / K->ne[2]; const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask; - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) { + if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; From a135b4c72e4e4bfa10f1f05b1b930ae271f3b156 Mon Sep 17 00:00:00 2001 From: Ben Jackson Date: Sun, 2 Mar 2025 19:21:04 -0800 Subject: [PATCH 12/13] ggml: Make fattn kernel use launch bounds w/HIP --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 21e59df6cccc8..622cf28576d29 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -19,9 +19,7 @@ namespace wmma = rocwmma; // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, From 373d48ef061cd3a640284dea02139cb85cf864ec Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Mon, 3 Mar 2025 22:23:27 +0800 Subject: [PATCH 13/13] Use GGML_CUDA_CC_IS_CDNA for checking CDNA architectures. --- ggml/src/ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0a801d7ae9dbc..1832314ec133b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -232,14 +232,14 @@ static bool fp16_mma_available(const int cc) { return false; #else return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || - cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || - cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3; + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
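
The device-code pattern the series converges on in PATCH 05 is a single wmma namespace alias that hides the nvcuda::wmma / rocwmma split, so flash_attn_ext_f16 is written once for both backends. The toy kernel below is a minimal, self-contained sketch of that pattern: the guard macros and the alias mirror the patch, while the 16x16x16 GEMM kernel itself, its name, and the fp16 includes are illustrative additions rather than ggml code.

    // Minimal illustration of the namespace-alias pattern from PATCH 05 (toy kernel, not ggml code).
    #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(GGML_HIP_ROCWMMA_FATTN)
    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>
    #include <rocwmma/rocwmma.hpp>
    namespace wmma = rocwmma;       // AMD: CDNA / RDNA3+ matrix cores via rocWMMA
    #else
    #include <cuda_fp16.h>
    #include <mma.h>
    namespace wmma = nvcuda::wmma;  // NVIDIA: Volta+ tensor cores
    #endif

    // One kernel body serves both backends; only the alias above differs.
    __global__ void toy_mma_16x16x16(const half * A, const half * B, float * C) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16);   // leading dimension 16
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
    }

Compared with the per-call #ifdef version introduced in PATCH 02, the alias removes the duplicated call sites entirely, which is what the PATCH 05 diffstat reflects (18 insertions, 70 deletions in fattn-wmma-f16.cu).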
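
PATCH 11 and PATCH 12 replace the hard-coded WARP_SIZE (32) with the device's physical warp size, host-side via ggml_cuda_info().devices[id].warp_size and device-side via ggml_cuda_get_physical_warp_size() in the launch bounds and loop strides, because GCN/CDNA GPUs execute 64-wide wavefronts. The helper's definition is not part of this series; the sketch below is an assumed shape built on the __AMDGCN_WAVEFRONT_SIZE macro provided by HIP's clang, with only the name and its constexpr, per-architecture nature taken from the patches.

    // Assumed sketch of the device-side warp-size helper referenced by PATCH 11/12 (not shown in the series).
    static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
    #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
        return __AMDGCN_WAVEFRONT_SIZE;   // 64 on GCN/CDNA, 32 on RDNA wave32 builds
    #else
        return 32;                        // NVIDIA warp size
    #endif
    }

On the host side, launch_fattn() then asserts that block_dim.x is a multiple of this warp size and that AMD blocks stay within 4 wavefronts, as visible in the PATCH 11 hunk above.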