77#include " fattn-wmma-f16.cuh"
88
99#ifdef FP16_MMA_AVAILABLE
10+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1011#include < mma.h>
12+ #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
13+ #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
14+ #include < rocwmma/rocwmma.hpp>
15+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1116#endif // FP16_MMA_AVAILABLE
1217
1318// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
5156 const int ne1,
5257 const int ne2,
5358 const int ne3) {
54- #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
59+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))
5560 // Skip unused kernel variants for faster compilation:
5661 if (use_logit_softcap && !(D == 128 || D == 256 )) {
5762 NO_DEVICE_CODE;
@@ -68,11 +73,19 @@ static __global__ void flash_attn_ext_f16(
6873 constexpr int frag_m = ncols == 8 ? 32 : 16 ;
6974 constexpr int frag_n = ncols == 8 ? 8 : 16 ;
7075 static_assert (D % frag_m == 0 , " If ncols == 8 then D % frag_m must be 0." );
76+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
7177 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16 , half, nvcuda::wmma::row_major> frag_a_K;
7278 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16 , half, nvcuda::wmma::col_major> frag_a_V;
7379 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16 , half, nvcuda::wmma::col_major> frag_b;
7480 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16 , KQ_acc_t> frag_c_KQ;
7581 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16 , half> frag_c_VKQ;
82+ #else
83+ typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16 , half, rocwmma::row_major> frag_a_K;
84+ typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16 , half, rocwmma::col_major> frag_a_V;
85+ typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16 , half, rocwmma::col_major> frag_b;
86+ typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16 , KQ_acc_t> frag_c_KQ;
87+ typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16 , half> frag_c_VKQ;
88+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
7689
7790 constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
7891 constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -162,7 +175,11 @@ static __global__ void flash_attn_ext_f16(
162175 for (int i0 = 0 ; i0 < D; i0 += 16 ) {
163176#pragma unroll
164177 for (int j0 = 0 ; j0 < ncols; j0 += frag_n) {
178+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
165179 nvcuda::wmma::load_matrix_sync (Q_b[i0/16 ][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
180+ #else
181+ rocwmma::load_matrix_sync (Q_b[i0/16 ][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
182+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
166183 }
167184 }
168185
@@ -176,20 +193,36 @@ static __global__ void flash_attn_ext_f16(
176193 frag_c_KQ KQ_c[ncols/frag_n];
177194#pragma unroll
178195 for (int j = 0 ; j < ncols/frag_n; ++j) {
196+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
179197 nvcuda::wmma::fill_fragment (KQ_c[j], 0 .0f );
198+ #else
199+ rocwmma::fill_fragment (KQ_c[j], static_cast <KQ_acc_t>(0 .0f ));
200+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
180201 }
181202#pragma unroll
182203 for (int k_KQ_0 = 0 ; k_KQ_0 < D; k_KQ_0 += 16 ) {
183204 frag_a_K K_a;
205+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
184206 nvcuda::wmma::load_matrix_sync (K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx .y )*stride_KV + k_KQ_0, stride_KV);
207+ #else
208+ rocwmma::load_matrix_sync (K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx .y )*stride_KV + k_KQ_0, stride_KV);
209+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
185210#pragma unroll
186211 for (int j = 0 ; j < ncols/frag_n; ++j) {
212+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
187213 nvcuda::wmma::mma_sync (KQ_c[j], K_a, Q_b[k_KQ_0/16 ][j], KQ_c[j]);
214+ #else
215+ rocwmma::mma_sync (KQ_c[j], K_a, Q_b[k_KQ_0/16 ][j], KQ_c[j]);
216+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
188217 }
189218 }
190219#pragma unroll
191220 for (int j0 = 0 ; j0 < ncols; j0 += frag_n) {
221+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
192222 nvcuda::wmma::store_matrix_sync ((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx .y , KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
223+ #else
224+ rocwmma::store_matrix_sync ((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx .y , KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
225+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
193226 }
194227 }
195228
@@ -308,10 +341,17 @@ static __global__ void flash_attn_ext_f16(
308341#pragma unroll
309342 for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16 ) {
310343 const int k = k0 + (threadIdx .y % VKQ_ratio)*16 ;
344+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
311345 nvcuda::wmma::load_matrix_sync (
312346 KQ_b[k0/(VKQ_ratio*16 )][j0/frag_n],
313347 KQ + j0*(kqar*kqs_padded) + k,
314348 kqar*kqs_padded);
349+ #else
350+ rocwmma::load_matrix_sync (
351+ KQ_b[k0/(VKQ_ratio*16 )][j0/frag_n],
352+ KQ + j0*(kqar*kqs_padded) + k,
353+ kqar*kqs_padded);
354+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
315355 }
316356 }
317357
@@ -320,18 +360,30 @@ static __global__ void flash_attn_ext_f16(
320360 for (int i_VKQ_0 = 0 ; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
321361#pragma unroll
322362 for (int j = 0 ; j < ncols/frag_n; ++j) {
363+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
323364 nvcuda::wmma::fill_fragment (VKQ_c[i_VKQ_0/VKQ_stride][j], 0 .0f );
365+ #else
366+ rocwmma::fill_fragment (VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast <half>(0 .0f ));
367+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
324368 }
325369
326370#pragma unroll
327371 for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16 ) {
328372 const int k = k0 + (threadIdx .y % VKQ_ratio)*16 ;
329373
330374 frag_a_V v_a;
375+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
331376 nvcuda::wmma::load_matrix_sync (v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx .y /VKQ_ratio), stride_KV);
377+ #else
378+ rocwmma::load_matrix_sync (v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx .y /VKQ_ratio), stride_KV);
379+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
332380#pragma unroll
333381 for (int j = 0 ; j < ncols/frag_n; ++j) {
382+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
334383 nvcuda::wmma::mma_sync (VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16 )][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
384+ #else
385+ rocwmma::mma_sync (VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16 )][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
386+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
335387 }
336388 }
337389 }
@@ -343,10 +395,17 @@ static __global__ void flash_attn_ext_f16(
343395 for (int i_KQ_0 = 0 ; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
344396#pragma unroll
345397 for (int j0 = 0 ; j0 < ncols; j0 += frag_n) {
398+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
346399 nvcuda::wmma::store_matrix_sync (
347400 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx .y /VKQ_ratio),
348401 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349402 D_padded, nvcuda::wmma::mem_col_major);
403+ #else
404+ rocwmma::store_matrix_sync (
405+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx .y /VKQ_ratio),
406+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
407+ D_padded, rocwmma::mem_col_major);
408+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
350409 }
351410 }
352411
@@ -425,7 +484,7 @@ static __global__ void flash_attn_ext_f16(
425484 }
426485#else
427486 NO_DEVICE_CODE;
428- #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
487+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))
429488}
430489
431490constexpr int get_max_power_of_2 (int x) {
@@ -574,6 +633,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
574633 if (Q->ne [1 ] <= 8 && Q->ne [0 ] % WARP_SIZE == 0 ) {
575634 constexpr int cols_per_block = 8 ;
576635 switch (Q->ne [0 ]) {
636+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
577637 case 64 :
578638 ggml_cuda_flash_attn_ext_wmma_f16_case< 64 , cols_per_block, half>(ctx, dst);
579639 break ;
@@ -586,6 +646,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
586646 case 256 :
587647 ggml_cuda_flash_attn_ext_wmma_f16_case<256 , cols_per_block, half>(ctx, dst);
588648 break ;
649+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
589650 default :
590651 GGML_ABORT (" fatal error" );
591652 break ;
0 commit comments