 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE

@@ -73,19 +75,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
-#else
-    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
-    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
-    typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
-    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -175,11 +169,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-#else
-            rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -193,36 +183,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-#else
-            rocwmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+            wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#else
-            rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-#else
-                rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-#else
-            rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }

@@ -341,17 +315,10 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
-#else
-                rocwmma::load_matrix_sync(
-                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
-                    kqar*kqs_padded);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             }
         }

@@ -360,30 +327,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-#else
-                rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }

 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#else
-                rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-#else
-                    rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -395,17 +350,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
-#else
-                rocwmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, rocwmma::mem_col_major);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                    D_padded, wmma::mem_col_major);
             }
         }

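Note: the change above hinges on one fact — nvcuda::wmma and rocwmma expose the same fragment types and load/fill/mma/store entry points, so a single namespace alias chosen at preprocessing time lets the kernel body be written once for both backends. Below is a minimal standalone sketch of the same pattern, for illustration only: the kernel name is hypothetical and the preprocessor guard is simplified relative to this diff (which additionally checks GGML_HIP_ROCWMMA_FATTN before pulling in rocWMMA).

```cpp
// Sketch of the namespace-alias pattern used in the diff above.
// The guard mirrors the diff's CUDA/HIP split in simplified form.
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <cuda_fp16.h>
#include <mma.h>
namespace wmma = nvcuda::wmma;
#else
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif

// One warp computes a single 16x16 FP16 tile C = A*B with a float accumulator.
// Every fragment type and call resolves through the alias, so this body
// compiles unchanged on both backends.
static __global__ void wmma_alias_sketch(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;

    wmma::fill_fragment(c, 0.0f);      // zero the accumulator, as with KQ_c above
    wmma::load_matrix_sync(a, A, 16);  // leading dimension 16 for a dense tile
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);        // c += a*b
    wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
}
```

Because the alias is resolved at compile time there is no runtime dispatch; any divergence between the two APIs would surface as a compile error in whichever backend is being built, which is why the per-call #ifdef blocks removed in this diff were redundant.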