@@ -20,9 +20,24 @@ namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
 #endif // GGML_USE_WMMA_FATTN
 
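+// Minimum number of resident blocks per SM requested via __launch_bounds__ below:
+// the rocWMMA build on AMD GPUs asks for 2 (presumably to improve occupancy there),
+// all other builds keep the previous value of 1.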
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
+#else
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
+#endif
+
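+// Per-iteration KQ tile height along the KV dimension: the rocWMMA build on AMD GPUs
+// uses a smaller 128-row tile for head sizes <= 128 (presumably to cut the shared-memory
+// footprint of the KQ tile); all other builds keep FATTN_KQ_STRIDE.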
+template <int D>
+constexpr int ggml_wmma_fattn_kq_stride() {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    return D <= 128 ? 128 : FATTN_KQ_STRIDE;
+#else
+    return FATTN_KQ_STRIDE;
+#endif
+}
+
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template <int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
-__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -55,10 +70,12 @@ static __global__ void flash_attn_ext_f16(
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
+
+    static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");
 
     const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -75,7 +92,7 @@ static __global__ void flash_attn_ext_f16(
7592
7693 // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
7794 constexpr int D_padded = D + 8 ;
78- constexpr int kqs_padded = FATTN_KQ_STRIDE + 8 ;
95+ constexpr int kqs_padded = fattn_kq_stride + 8 ;
7996 constexpr int kqar = sizeof (KQ_acc_t)/sizeof (half);
8097
8198 const int sequence = blockIdx .z / ne02;
@@ -168,10 +185,10 @@ static __global__ void flash_attn_ext_f16(
168185
169186 // Iterate over ne11 == previous tokens:
170187 const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim .x + blockIdx .x ] : ne11;
171- for (int k_VKQ_0 = blockIdx .y *FATTN_KQ_STRIDE ; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim .y *FATTN_KQ_STRIDE ) {
188+ for (int k_VKQ_0 = blockIdx .y *fattn_kq_stride ; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim .y *fattn_kq_stride ) {
172189 // Calculate tile of KQ:
173190#pragma unroll
174- for (int i_KQ_0 = 0 ; i_KQ_0 < FATTN_KQ_STRIDE ; i_KQ_0 += KQ_stride_tc) {
191+ for (int i_KQ_0 = 0 ; i_KQ_0 < fattn_kq_stride ; i_KQ_0 += KQ_stride_tc) {
175192 frag_c_KQ KQ_c[ncols/frag_n];
176193#pragma unroll
177194 for (int j = 0 ; j < ncols/frag_n; ++j) {
@@ -201,9 +218,9 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+                float KQ_f_tmp[fattn_kq_stride / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -215,7 +232,7 @@ static __global__ void flash_attn_ext_f16(
215232
216233 float KQ_max_new = KQ_max_f[j0/nwarps];
217234#pragma unroll
218- for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE ; k0 += warp_size) {
235+ for (int k0 = 0 ; k0 < fattn_kq_stride ; k0 += warp_size) {
219236 const int k = k0 + threadIdx .x ;
220237
221238 KQ_f_tmp[k0/warp_size] += mask ? __half2float (slopeh*maskh[j*(nb31/sizeof (half)) + k_VKQ_0 + k]) : 0 .0f ;
@@ -232,7 +249,7 @@ static __global__ void flash_attn_ext_f16(
232249
233250 float KQ_rowsum_add = 0 .0f ;
234251#pragma unroll
235- for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE ; k0 += warp_size) {
252+ for (int k0 = 0 ; k0 < fattn_kq_stride ; k0 += warp_size) {
236253 const int k = k0 + threadIdx .x ;
237254
238255 const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
@@ -248,9 +265,9 @@ static __global__ void flash_attn_ext_f16(
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+                half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -267,7 +284,7 @@ static __global__ void flash_attn_ext_f16(
267284
268285 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
269286#pragma unroll
270- for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE /2 ; k0 += warp_size) {
287+ for (int k0 = 0 ; k0 < fattn_kq_stride /2 ; k0 += warp_size) {
271288 const int k = k0 + threadIdx .x ;
272289
273290 KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
@@ -282,7 +299,7 @@ static __global__ void flash_attn_ext_f16(
282299
283300 half2 KQ_rowsum_add = make_half2 (0 .0f , 0 .0f );
284301#pragma unroll
285- for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE /2 ; k0 += warp_size) {
302+ for (int k0 = 0 ; k0 < fattn_kq_stride /2 ; k0 += warp_size) {
286303 const int k = k0 + threadIdx .x ;
287304
288305 const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
@@ -301,11 +318,11 @@ static __global__ void flash_attn_ext_f16(
 
         __syncthreads();
 
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+        frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -323,7 +340,7 @@ static __global__ void flash_attn_ext_f16(
323340 }
324341
325342#pragma unroll
326- for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE ; k0 += VKQ_ratio*16 ) {
343+ for (int k0 = 0 ; k0 < fattn_kq_stride ; k0 += VKQ_ratio*16 ) {
327344 const int k = k0 + (threadIdx .y % VKQ_ratio)*16 ;
328345
329346 frag_a_V v_a;
@@ -512,7 +529,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
 
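+    // The rocWMMA build on AMD GPUs launches 8 warps per block for head sizes <= 96
+    // and 4 otherwise; all other builds keep the previous fixed 4 warps.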
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    constexpr int nwarps = D <= 96 ? 8 : 4;
+#else
     constexpr int nwarps = 4;
+#endif
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -530,7 +552,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
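+    // Pass the kernel's per-iteration KV stride (instead of the fixed FATTN_KQ_STRIDE)
+    // so the launcher sees the same, possibly reduced, tile size.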
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {