Skip to content

Commit a45e1cd

Browse files
committed
HIP: use WMMA for prefill only; fix decode regression by enabling TILE and adding a safe fallback

- Do not select WMMA for decode on HIP; fall through to VEC/TILE
- Remove WMMA TILE pruning on HIP to avoid device traps; keep for CUDA WMMA
- Add decode-time guard: if predicted TILE split has no config, select VEC
- Remove ad-hoc env overrides and debug prints
1 parent a3c9d1d commit a45e1cd

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,9 +721,12 @@ static __global__ void flash_attn_tile(
721721

722722
// Skip unused kernel variants for faster compilation:
723723

724+
// Optionally disable pruning to keep all TILE variants for testing.
725+
#if !defined(GGML_USE_HIP)
724726
if (
725727
#ifdef GGML_USE_WMMA_FATTN
726-
(ncols2 != 1 && DV != 40 && DV != 512) ||
728+
// On CUDA WMMA builds, prune some TILE variants to reduce compile time/binary size.
729+
(ncols2 != 1 && DV != 40 && DV != 64 && DV != 128 && DV != 256 && DV != 512) ||
727730
#endif // GGML_USE_WMMA_FATTN
728731
(use_logit_softcap && !(DV == 128 || DV == 256))
729732
) {
@@ -739,6 +742,7 @@ static __global__ void flash_attn_tile(
739742
NO_DEVICE_CODE;
740743
return;
741744
}
745+
#endif // !defined(GGML_USE_HIP)
742746

743747
static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
744748

ggml/src/ggml-cuda/fattn.cu

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,13 +301,66 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
301301
}
302302

303303
// Use the WMMA kernel if possible:
304-
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
304+
#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
305+
const bool hip_wmma_decode = Q->ne[1] == 1;
306+
#else
307+
const bool hip_wmma_decode = false;
308+
#endif
309+
if (!hip_wmma_decode && ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
305310
if (can_use_vector_kernel && Q->ne[1] <= 2) {
306311
return BEST_FATTN_KERNEL_VEC;
307312
}
308313
return BEST_FATTN_KERNEL_WMMA_F16;
309314
}
310315

316+
// HIP decode path (Q->ne[1] == 1): fall through to generic HIP selection below (VEC/TILE),
317+
// with a guard to avoid selecting a TILE shape that has no config.
318+
if (hip_wmma_decode) {
319+
#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
320+
// Mirror the ncols2 selection from launch_fattn_tile_switch_ncols2 to predict if
321+
// a multi-column TILE kernel (ncols2 != 1) would be chosen.
322+
const bool nvidia_arch = GGML_CUDA_CC_IS_NVIDIA(cc);
323+
const int gqa_limit = (nvidia_arch && gqa_ratio <= 4) ? 16 : INT_MAX;
324+
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
325+
326+
int predicted_ncols2 = 1;
327+
if (V->ne[0] == 512) {
328+
if (use_gqa_opt && gqa_ratio % 16 == 0) predicted_ncols2 = 16;
329+
} else if (V->ne[0] <= 256) {
330+
if (use_gqa_opt && gqa_ratio % 8 == 0) predicted_ncols2 = 8;
331+
else if (use_gqa_opt && gqa_ratio % 4 == 0) predicted_ncols2 = 4;
332+
else if (use_gqa_opt && gqa_ratio % 2 == 0) predicted_ncols2 = 2;
333+
}
334+
335+
// Predict cols_per_block like launch_fattn_tile_switch_ncols1 does (HIP path):
336+
int predicted_cols_per_block = 2;
337+
if (predicted_ncols2 <= 2) {
338+
predicted_cols_per_block = 2;
339+
}
340+
if (predicted_ncols2 <= 4 && Q->ne[1] > 2/predicted_ncols2) {
341+
predicted_cols_per_block = 4;
342+
}
343+
if (predicted_ncols2 <= 8 && Q->ne[1] > 4/predicted_ncols2) {
344+
predicted_cols_per_block = 8;
345+
}
346+
if (Q->ne[1] > 8/predicted_ncols2) {
347+
predicted_cols_per_block = 16;
348+
}
349+
if (Q->ne[1] > 16/predicted_ncols2) {
350+
predicted_cols_per_block = 32;
351+
}
352+
if (V->ne[0] <= 128 && Q->ne[1] > 32/predicted_ncols2) {
353+
predicted_cols_per_block = 64;
354+
}
355+
356+
const uint32_t cfg = ggml_cuda_fattn_tile_get_config((int)Q->ne[0], (int)V->ne[0], predicted_cols_per_block, cc);
357+
if (predicted_ncols2 != 1 && cfg == 0) {
358+
return BEST_FATTN_KERNEL_VEC;
359+
}
360+
#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
361+
// Otherwise, fall through.
362+
}
363+
311364
// If there are no tensor cores available, use the generic tile kernel:
312365
if (can_use_vector_kernel) {
313366
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {

0 commit comments

Comments (0)