 #include "fattn-common.cuh"

 #ifdef FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = ::rocwmma;
+#else
 #include <mma.h>
+namespace wmma = ::nvcuda::wmma;
+#endif
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
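// Aside (not part of the diff): a minimal, self-contained sketch of the portability pattern
// introduced above — a single `wmma` namespace alias that resolves to rocWMMA on HIP/AMD and
// to nvcuda::wmma on CUDA, so the same tensor-core code compiles on both backends. The kernel
// name and the fixed 16x16x16 tile shape are illustrative only, and `half` is assumed to map
// to the backend's fp16 type as it does in ggml.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <rocwmma/rocwmma.hpp>
namespace wmma = ::rocwmma;
#else
#include <mma.h>
namespace wmma = ::nvcuda::wmma;
#endif

static __global__ void wmma_alias_demo(const half * A, const half * B, float * C) {
    // One 16x16x16 half-precision tile multiply-accumulate into a float accumulator.
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);     // float accumulator, so a plain float literal is fine
    wmma::load_matrix_sync(a, A, 16); // leading dimension of the 16x16 tiles
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
}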
@@ -63,11 +69,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m;        // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -157,7 +163,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -171,20 +177,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], KQ_acc_t(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }

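// Aside (not part of the diff): the fill_fragment changes in this hunk and in the VKQ hunk
// below wrap the zero literal in the accumulator's element type (KQ_acc_t(0.0f) here,
// half(0.0f) later). The likely reason — an assumption, not stated in the patch — is that
// rocWMMA's fill_fragment expects the fragment's element type directly, and a bare float
// literal would not implicitly convert to half on HIP. A minimal sketch of the same idea,
// using a hypothetical helper name:
template <typename frag_t, typename elem_t>
static __device__ void zero_fragment(frag_t & frag) {
    // Construct the zero in the fragment's own element type instead of relying on an
    // implicit float -> elem_t conversion inside fill_fragment.
    wmma::fill_fragment(frag, elem_t(0.0f));
}
// Usage in the style of this kernel: zero_fragment<frag_c_VKQ, half>(VKQ_c[i][j]);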
@@ -303,7 +309,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -315,18 +321,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], half(0.0f));
             }

 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -338,10 +344,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
