@@ -26,9 +26,24 @@ namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
 #endif // GGML_USE_WMMA_FATTN
 
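+// Minimum resident blocks per SM/CU requested via __launch_bounds__ below; the value 2
+// under rocWMMA on HIP is presumably an occupancy tuning choice for AMD hardware.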
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
+#else
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
+#endif
+
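+// KQ tile stride as a function of head size D: under rocWMMA on HIP, head sizes up to 128
+// use a narrower 128-wide KQ tile, presumably to shrink the shared-memory footprint;
+// everywhere else the global FATTN_KQ_STRIDE is kept.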
+template <int D>
+constexpr int ggml_wmma_fattn_kq_stride() {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    return D <= 128 ? 128 : FATTN_KQ_STRIDE;
+#else
+    return FATTN_KQ_STRIDE;
+#endif
+}
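+// e.g. ggml_wmma_fattn_kq_stride<64>() == 128 under rocWMMA on HIP, FATTN_KQ_STRIDE otherwise.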
+
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template <int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
-__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -61,10 +76,12 @@ static __global__ void flash_attn_ext_f16(
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
+
+    static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");
 
     const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -81,7 +98,7 @@ static __global__ void flash_attn_ext_f16(
 
     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
     constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqs_padded = fattn_kq_stride + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int sequence = blockIdx.z / ne02;
@@ -174,10 +191,10 @@ static __global__ void flash_attn_ext_f16(
 
     // Iterate over ne11 == previous tokens:
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*fattn_kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*fattn_kq_stride) {
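+        // Blocks grid-stride over the KV sequence in tiles of fattn_kq_stride tokens.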
         // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+        for (int i_KQ_0 = 0; i_KQ_0 < fattn_kq_stride; i_KQ_0 += KQ_stride_tc) {
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
@@ -207,9 +224,9 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+                float KQ_f_tmp[fattn_kq_stride / warp_size];
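+                // Each lane caches its fattn_kq_stride/warp_size elements of KQ row j in registers.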
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -221,7 +238,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
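+                    // Add the slope-scaled mask (e.g. ALiBi slopes) before updating the row maximum.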
@@ -238,7 +255,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
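+                    // Shift by the row maximum for a numerically stable softmax exponent.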
@@ -254,9 +271,9 @@ static __global__ void flash_attn_ext_f16(
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+                half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
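+                // half2 path: two KQ values per register, so the loop bounds below are halved.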
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -273,7 +290,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
@@ -288,7 +305,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
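+                    // Same stable-softmax shift as the float path, two values at a time.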
@@ -307,11 +324,11 @@ static __global__ void flash_attn_ext_f16(
 
         __syncthreads();
 
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+        frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
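+        // Stage the softmax(KQ) tile from shared memory into WMMA b-fragments for the VKQ matrix multiplication.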
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -329,7 +346,7 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+        for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
             frag_a_V v_a;
@@ -518,7 +535,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
 
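+    // Under rocWMMA on HIP, run more warps per block for head sizes up to 96; presumably a tuning choice.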
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    constexpr int nwarps = D <= 96 ? 8 : 4;
+#else
     constexpr int nwarps = 4;
+#endif
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
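+    // Host-side mirror of the kernel's KQ tile stride; passed to launch_fattn below.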
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -536,7 +558,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {