Skip to content

Commit d7c2f27

Browse files
committed
try to fix some fattn inconsistencies
1 parent c12f9e3 commit d7c2f27

File tree

1 file changed

+5 -2 lines changed

1 file changed

+5 -2 lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -306,7 +306,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
306306

307307
//kcpp: use wmma to fix cu11 incoherence
308308
if (ggml_cuda_should_use_wmma_fattn(cc) && (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING)) {
309-
return BEST_FATTN_KERNEL_WMMA_F16;
309+
if(Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) //kcpp: these sizes not supported in wmma
310+
{
311+
return BEST_FATTN_KERNEL_WMMA_F16;
312+
}
310313
}
311314

312315
return BEST_FATTN_KERNEL_MMA_F16;
@@ -330,7 +333,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
330333
}
331334
}
332335
//kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
333-
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
336+
if ((Q->ne[1] <= 8 || Q->ne[0] == 256) && can_use_vector_kernel) {
334337
return BEST_FATTN_KERNEL_VEC;
335338
}
336339

0 commit comments

Comments
 (0)