Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032
```diff
@@ -7,7 +7,12 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
```
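The gating introduced here recurs at every WMMA call site below: a plain `#if`/`#else` between the `nvcuda::wmma` and `rocwmma` namespaces. Since rocWMMA deliberately mirrors the CUDA WMMA interface, a more compact alternative would be to pick the namespace once with an alias. A minimal sketch of that idea (not what this PR does, and assuming the same guard macros as above):

```cpp
// Sketch only: select the WMMA backend once via a namespace alias instead of
// repeating #if/#else at every call site. Not part of this PR.
#ifdef FP16_MMA_AVAILABLE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
namespace wmma = nvcuda::wmma;        // NVIDIA tensor core path
#elif defined(GGML_HIP_ROCWMMA_FATTN)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS  // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;             // AMD matrix core path via rocWMMA
#endif
#endif // FP16_MMA_AVAILABLE
```

With such an alias, every `load_matrix_sync`/`mma_sync`/`store_matrix_sync` call below could be written once as `wmma::...`.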
|
|
```diff
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
```
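The kernel body remains Volta-only on CUDA (`__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA`; newer NVIDIA architectures are served by a separate tensor-core FlashAttention kernel in this codebase), and the added clause now also compiles it for HIP builds with `GGML_HIP_ROCWMMA_FATTN` enabled on FP16-MMA-capable GPUs.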
|
|
```diff
@@ -68,11 +73,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+#else
+    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
+    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
+    typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
+    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
```
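rocWMMA mirrors the `nvcuda::wmma` template interface, so each typedef translates one-to-one: the `<use, M, N, K, element type, layout>` parameters stay identical and only the namespace changes. A hypothetical standalone snippet of the correspondence (not from the diff):

```cpp
#include <hip/hip_fp16.h>        // provides half on the HIP side
#include <rocwmma/rocwmma.hpp>

// nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major>
// maps, with only the namespace swapped, to:
typedef rocwmma::fragment<rocwmma::matrix_a, 16, 16, 16, half, rocwmma::row_major> frag_a_row;
typedef rocwmma::fragment<rocwmma::matrix_b, 16, 16, 16, half, rocwmma::col_major> frag_b_col;
typedef rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, float> frag_acc;
```

Note that the `ncols == 8` path uses 32×8 fragments; the host dispatch at the end of this diff compiles those cases out on HIP, so only the 16×16 shapes are exercised through rocWMMA here.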
|
|
```diff
@@ -162,7 +175,11 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+#else
+            rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
         }
     }
 
```
|
|
|
|
```diff
@@ -176,20 +193,36 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+#else
+            rocwmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#else
+            rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+#else
+                rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+#else
+            rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
         }
     }
```
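Each of these per-call-site hunks instantiates the same four-step WMMA idiom: zero an accumulator fragment, load the A/B tiles, `mma_sync`, store the result. For reference, a self-contained sketch of that idiom on the rocWMMA side (hypothetical helper, not from the PR; 16×16×16 tiles, in-bounds pointers assumed):

```cpp
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>

// Hypothetical device helper: C (col-major) = A (row-major) * B (col-major)
// for one 16x16 tile per wavefront -- the same sequence the kernel above uses.
__device__ void tile_gemm_16x16x16(const half * A, const half * B, float * C,
                                   int lda, int ldb, int ldc) {
    rocwmma::fragment<rocwmma::matrix_a, 16, 16, 16, half, rocwmma::row_major> a;
    rocwmma::fragment<rocwmma::matrix_b, 16, 16, 16, half, rocwmma::col_major> b;
    rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, float> c;

    rocwmma::fill_fragment(c, 0.0f);       // zero the accumulator tile
    rocwmma::load_matrix_sync(a, A, lda);  // load one 16x16 tile of A
    rocwmma::load_matrix_sync(b, B, ldb);  // load one 16x16 tile of B
    rocwmma::mma_sync(c, a, b, c);         // c += a * b on the matrix cores
    rocwmma::store_matrix_sync(C, c, ldc, rocwmma::mem_col_major);
}
```

The one asymmetry against the `nvcuda` path is the fill value: rocWMMA's `fill_fragment` takes the value in the fragment's element type, which is presumably why the HIP branch writes `static_cast<KQ_acc_t>(0.0f)` where the CUDA call passes a bare `0.0f`.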
|
|
|
|
```diff
@@ -308,10 +341,17 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 nvcuda::wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
+#else
+                rocwmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             }
         }
 
```
|
|
|
|
```diff
@@ -320,18 +360,30 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+#else
+                rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             }
 
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#else
+                rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                     nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+#else
+                    rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 }
             }
         }
```
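As with the KQ accumulators, the rocWMMA fill value here is cast to the fragment's element type (`half`); apart from that, the V·softmax(KQ) hunk is a verbatim namespace swap, with each warp handling the k-slices selected by `threadIdx.y % VKQ_ratio`.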
|
|
```diff
@@ -343,10 +395,17 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
                 nvcuda::wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                     D_padded, nvcuda::wmma::mem_col_major);
+#else
+                rocwmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, rocwmma::mem_col_major);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             }
         }
```
|
|
|
|
```diff
@@ -425,7 +484,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
```
|
|
```diff
@@ -574,6 +633,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             case 64:
                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
```
|
|
```diff
@@ -586,6 +646,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             default:
                 GGML_ABORT("fatal error");
                 break;
```
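The guard opened in the previous hunk closes here: on HIP, the `cols_per_block == 8` head-size cases (64 through 256) are compiled out and would fall through to `default` (presumably the surrounding dispatch prevents reaching it there), so the small-batch kernel variants remain NVIDIA-only in this PR. Enabling the new path is opt-in at build time; assuming llama.cpp's usual HIP CMake setup, something like `cmake -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON` with the rocWMMA headers installed should turn it on — the flag name comes from this PR's title, the rest of the invocation is an assumption.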
|
|