File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -306,7 +306,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
306306
307307 // kcpp: use wmma to fix cu11 incoherence
308308 if (ggml_cuda_should_use_wmma_fattn (cc) && (ggml_cuda_highest_compiled_arch (cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING)) {
309- return BEST_FATTN_KERNEL_WMMA_F16;
309+ if (Q->ne [0 ] != 40 && Q->ne [0 ] != 72 && Q->ne [0 ] != 576 ) // kcpp: these sizes not supported in wmma
310+ {
311+ return BEST_FATTN_KERNEL_WMMA_F16;
312+ }
310313 }
311314
312315 return BEST_FATTN_KERNEL_MMA_F16;
@@ -330,7 +333,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
330333 }
331334 }
332335 // kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
333- if (Q->ne [1 ] <= 8 || Q->ne [0 ] == 256 ) {
336+ if (( Q->ne [1 ] <= 8 || Q->ne [0 ] == 256 ) && can_use_vector_kernel ) {
334337 return BEST_FATTN_KERNEL_VEC;
335338 }
336339
You can’t perform that action at this time.
0 commit comments