Commit 1bfc9af

Revert "try fix fattn again, porting some older code. the cc detection is not working well, so its hacky"
This reverts commit 7b04191.
1 parent: d1907a0


ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 14 deletions
@@ -298,12 +298,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { //kcpp: fix for rocwmma
-        return BEST_FATTN_KERNEL_WMMA_F16;
-    }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
     switch (K->ne[0]) {
         case 64:
         case 128:
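
The block removed in the hunk above is a build-time shortcut: on rocWMMA builds it sent every AMD GPU with fp16 MMA support straight to the WMMA kernel, before any of the head-size checks that follow. A minimal, self-contained sketch of that gating pattern is below; the macro values and helper body are simplified stand-ins for the real ggml-cuda definitions, shown only to illustrate the control flow.

#include <cstdio>

// Hypothetical stand-ins; the real definitions live in ggml-cuda.
#define GGML_HIP_ROCWMMA_FATTN               // assume a rocWMMA build
#define GGML_CUDA_CC_OFFSET_AMD 1000000      // AMD ccs are offset in ggml-cuda
#define GGML_CUDA_CC_IS_AMD(cc) ((cc) >= GGML_CUDA_CC_OFFSET_AMD)

enum best_fattn_kernel { BEST_FATTN_KERNEL_WMMA_F16, BEST_FATTN_KERNEL_NONE };

// Stub: pretend the probed device has fp16 MMA (WMMA) units.
static bool fp16_mma_available(int /*cc*/) { return true; }

static best_fattn_kernel try_amd_shortcut(int cc) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
    // The shortcut this revert removes: return the WMMA kernel for AMD
    // devices before any head-size checks run.
    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
        return BEST_FATTN_KERNEL_WMMA_F16;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
    return BEST_FATTN_KERNEL_NONE; // fall through to the normal selection
}

int main() {
    printf("AMD cc    -> %d\n", try_amd_shortcut(GGML_CUDA_CC_OFFSET_AMD + 1030)); // 0 = WMMA_F16
    printf("NVIDIA cc -> %d\n", try_amd_shortcut(750));                            // 1 = NONE
}

Because the shortcut returned before the switch on K->ne[0], it bypassed the head-size checks entirely; the revert removes it along with the rest of 7b04191.
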
@@ -421,21 +415,15 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
-    //kcpp: always force WMMA for Turing and Volta if above check fails, fix "FlashAttention without tensor cores only supports head sizes 64 and 128."
-    if (cc == GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_VOLTA) {
+    //kcpp: always force WMMA for older gpus, fix issues like "FlashAttention without tensor cores only supports head sizes 64 and 128."
+    if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING) {
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
     // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
     if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-            return BEST_FATTN_KERNEL_VEC_F16; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
-        }
         return BEST_FATTN_KERNEL_TILE_F16;
     }
-    if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-        return BEST_FATTN_KERNEL_VEC_F32; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
-    }
     return BEST_FATTN_KERNEL_TILE_F32;
 }
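
Taken together, the two hunks leave the tail of ggml_cuda_get_best_fattn_kernel as modelled below. This is a minimal, self-contained sketch of the post-revert control flow only: the enum values and helper bodies are simplified stand-ins for the real ggml-cuda definitions, and the head-size handling above this point in the function is omitted.

#include <cstdio>

enum best_fattn_kernel {
    BEST_FATTN_KERNEL_WMMA_F16,
    BEST_FATTN_KERNEL_TILE_F16,
    BEST_FATTN_KERNEL_TILE_F32,
};

enum ggml_prec { GGML_PREC_DEFAULT, GGML_PREC_F32 };

constexpr int GGML_CUDA_CC_TURING = 750; // stand-in; Turing is cc 7.5

// Hypothetical stubs carrying the same names as the real helpers.
static int  ggml_cuda_highest_compiled_arch(int cc) { return cc; }
static bool fast_fp16_available(int cc)             { return cc >= 600; }

static best_fattn_kernel pick_fallback(int cc, ggml_prec prec) {
    // Restored by the revert: key on the highest compiled arch rather than
    // exact Turing/Volta matches when forcing the WMMA kernel.
    if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING) {
        return BEST_FATTN_KERNEL_WMMA_F16;
    }
    // The small-batch vec shortcuts (Q->ne[1] <= 8 or head size 256) are
    // gone; everything else lands on the generic tile kernels.
    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
        return BEST_FATTN_KERNEL_TILE_F16;
    }
    return BEST_FATTN_KERNEL_TILE_F32;
}

int main() {
    printf("cc 750, default prec -> %d\n", pick_fallback(750, GGML_PREC_DEFAULT)); // 0 = WMMA_F16
    printf("cc 860, default prec -> %d\n", pick_fallback(860, GGML_PREC_DEFAULT)); // 1 = TILE_F16
    printf("cc 860, f32 prec     -> %d\n", pick_fallback(860, GGML_PREC_F32));     // 2 = TILE_F32
}

With these stubs, the decision reduces to: Turing-or-older targets always get WMMA f16; newer targets get the tile kernel, in f16 when default precision and fast fp16 are available, otherwise in f32.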
