@@ -301,13 +301,66 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 }
 
     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    const bool hip_wmma_decode = Q->ne[1] == 1;
+#else
+    const bool hip_wmma_decode = false;
+#endif
+    if (!hip_wmma_decode && ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
+    // HIP decode path (Q->ne[1] == 1): fall through to generic HIP selection below (VEC/TILE),
+    // with a guard to avoid selecting a TILE shape that has no config.
+    if (hip_wmma_decode) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Mirror the ncols2 selection from launch_fattn_tile_switch_ncols2 to predict if
+        // a multi-column TILE kernel (ncols2 != 1) would be chosen.
+        const bool nvidia_arch = GGML_CUDA_CC_IS_NVIDIA(cc);
+        const int  gqa_limit   = (nvidia_arch && gqa_ratio <= 4) ? 16 : INT_MAX;
+        const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+        int predicted_ncols2 = 1;
+        if (V->ne[0] == 512) {
+            if (use_gqa_opt && gqa_ratio % 16 == 0) predicted_ncols2 = 16;
+        } else if (V->ne[0] <= 256) {
+            if      (use_gqa_opt && gqa_ratio % 8 == 0) predicted_ncols2 = 8;
+            else if (use_gqa_opt && gqa_ratio % 4 == 0) predicted_ncols2 = 4;
+            else if (use_gqa_opt && gqa_ratio % 2 == 0) predicted_ncols2 = 2;
+        }
+
+        // Predict cols_per_block like launch_fattn_tile_switch_ncols1 does (HIP path):
+        int predicted_cols_per_block = 2;
+        if (predicted_ncols2 <= 2) {
+            predicted_cols_per_block = 2;
+        }
+        if (predicted_ncols2 <= 4 && Q->ne[1] > 2/predicted_ncols2) {
+            predicted_cols_per_block = 4;
+        }
+        if (predicted_ncols2 <= 8 && Q->ne[1] > 4/predicted_ncols2) {
+            predicted_cols_per_block = 8;
+        }
+        if (Q->ne[1] > 8/predicted_ncols2) {
+            predicted_cols_per_block = 16;
+        }
+        if (Q->ne[1] > 16/predicted_ncols2) {
+            predicted_cols_per_block = 32;
+        }
+        if (V->ne[0] <= 128 && Q->ne[1] > 32/predicted_ncols2) {
+            predicted_cols_per_block = 64;
+        }
+
+        const uint32_t cfg = ggml_cuda_fattn_tile_get_config((int) Q->ne[0], (int) V->ne[0], predicted_cols_per_block, cc);
+        if (predicted_ncols2 != 1 && cfg == 0) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Otherwise, fall through.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
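To make the guard easier to trace by hand, here is a minimal standalone sketch of the same ncols2 / cols_per_block prediction, assuming the relevant shapes and the GQA ratio are passed in directly. `predict_tile_shape` and `tile_prediction` are hypothetical names for illustration only, not part of the ggml API, and the sketch folds the diff's mask/max_bias/FATTN_KQ_STRIDE check into a single `use_gqa_opt` flag.

```cpp
// Hypothetical standalone re-implementation of the prediction in the diff,
// useful only for tracing the logic; not ggml API.
#include <cstdint>
#include <cstdio>

struct tile_prediction {
    int ncols2;         // Q columns fused per GQA group
    int cols_per_block; // total Q columns processed per block
};

// n_q mirrors Q->ne[1], head_dim_v mirrors V->ne[0]; use_gqa_opt stands in for
// the mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0 check above.
static tile_prediction predict_tile_shape(int64_t n_q, int64_t head_dim_v,
                                          int gqa_ratio, bool use_gqa_opt) {
    int ncols2 = 1;
    if (head_dim_v == 512) {
        if (use_gqa_opt && gqa_ratio % 16 == 0) ncols2 = 16;
    } else if (head_dim_v <= 256) {
        if      (use_gqa_opt && gqa_ratio % 8 == 0) ncols2 = 8;
        else if (use_gqa_opt && gqa_ratio % 4 == 0) ncols2 = 4;
        else if (use_gqa_opt && gqa_ratio % 2 == 0) ncols2 = 2;
    }

    // Integer division is intentional: with ncols2 == 8, 4/8 == 0, so even a
    // single query column (decode) satisfies n_q > 4/ncols2.
    int cols_per_block = 2;
    if (ncols2 <= 4 && n_q >  2/ncols2)       cols_per_block =  4;
    if (ncols2 <= 8 && n_q >  4/ncols2)       cols_per_block =  8;
    if (n_q >  8/ncols2)                      cols_per_block = 16;
    if (n_q > 16/ncols2)                      cols_per_block = 32;
    if (head_dim_v <= 128 && n_q > 32/ncols2) cols_per_block = 64;

    return {ncols2, cols_per_block};
}

int main() {
    // Decode on a GQA model: one query column, head size 128, 8 Q heads per KV head.
    const tile_prediction p = predict_tile_shape(/*n_q=*/1, /*head_dim_v=*/128,
                                                 /*gqa_ratio=*/8, /*use_gqa_opt=*/true);
    std::printf("ncols2=%d cols_per_block=%d\n", p.ncols2, p.cols_per_block);
    // Prints "ncols2=8 cols_per_block=8": a multi-column TILE shape even though
    // only one query column exists, which is exactly the case the new guard
    // checks against ggml_cuda_fattn_tile_get_config() before committing to TILE.
    return 0;
}
```

The worked case shows why the guard is needed: decode (Q->ne[1] == 1) can still predict a multi-column TILE shape through the integer divisions, and when that shape has no TILE config the selection now returns BEST_FATTN_KERNEL_VEC instead of committing to an unavailable kernel.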