@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
463463 float scale;
464464 memcpy (&scale, KQV->op_params , sizeof (float ));
465465
466- if (Q->ne [0 ] % WARP_SIZE == 0 && Q->ne [1 ] == 1 ) {
466+ if (Q->ne [0 ] % WARP_SIZE == 0 && Q->ne [0 ] >= 128 && Q-> ne [ 1 ] == 1 ) {
467467 const int nwarps = Q->ne [0 ] / WARP_SIZE;
468468 const dim3 blocks_num (Q->ne [1 ], Q->ne [2 ], Q->ne [3 ]);
469469 const dim3 block_dim (WARP_SIZE, nwarps, 1 );
470470 const int shmem = 0 ;
471471 switch (Q->ne [0 ]) {
472- case 64 :
473- flash_attn_vec_ext_f16<64 >
474- <<<blocks_num, block_dim, shmem, main_stream>>> (
475- (const char *) Q->data , // Query
476- (const char *) K->data , // Key
477- (const char *) V->data , // Value
478- mask ? ((const char *) mask->data ) : nullptr , // Mask
479- (float *) KQV->data , // dst
480- scale,
481- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
482- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
483- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
484- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
485- K->nb [1 ], K->nb [2 ], K->nb [3 ],
486- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
487- );
488- break ;
472+ // case 64:
473+ // flash_attn_vec_ext_f16<64>
474+ // <<<blocks_num, block_dim, shmem, main_stream>>> (
475+ // (const char *) Q->data, // Query
476+ // (const char *) K->data, // Key
477+ // (const char *) V->data, // Value
478+ // mask ? ((const char *) mask->data) : nullptr, // Mask
479+ // (float *) KQV->data, // dst
480+ // scale,
481+ // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
482+ // K->ne[0], K->ne[1], K->ne[2], K->ne[3],
483+ // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
484+ // Q->nb[1], Q->nb[2], Q->nb[3],
485+ // K->nb[1], K->nb[2], K->nb[3],
486+ // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
487+ // );
488+ // break;
489489 // case 80:
490490 // flash_attn_vec_ext_f16<80>
491491 // <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
503503 // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
504504 // );
505505 // break;
506- case 96 :
507- flash_attn_vec_ext_f16<96 >
508- <<<blocks_num, block_dim, shmem, main_stream>>> (
509- (const char *) Q->data , // Query
510- (const char *) K->data , // Key
511- (const char *) V->data , // Value
512- mask ? ((const char *) mask->data ) : nullptr , // Mask
513- (float *) KQV->data , // dst
514- scale,
515- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
516- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
517- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
518- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
519- K->nb [1 ], K->nb [2 ], K->nb [3 ],
520- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
521- );
522- break ;
506+ // case 96:
507+ // flash_attn_vec_ext_f16<96>
508+ // <<<blocks_num, block_dim, shmem, main_stream>>> (
509+ // (const char *) Q->data, // Query
510+ // (const char *) K->data, // Key
511+ // (const char *) V->data, // Value
512+ // mask ? ((const char *) mask->data) : nullptr, // Mask
513+ // (float *) KQV->data, // dst
514+ // scale,
515+ // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
516+ // K->ne[0], K->ne[1], K->ne[2], K->ne[3],
517+ // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
518+ // Q->nb[1], Q->nb[2], Q->nb[3],
519+ // K->nb[1], K->nb[2], K->nb[3],
520+ // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
521+ // );
522+ // break;
523523 // case 112:
524524 // flash_attn_vec_ext_f16<112>
525525 // <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
583583 if (Q->ne [0 ] % 32 == 0 ) {
584584 if (Q->ne [1 ] >= 128 && Q->ne [0 ] <= 128 ) {
585585 cols_per_block = 64 ;
586- } else if (Q->ne [1 ] >= 64 ) {
586+ } else if (Q->ne [1 ] >= 64 && (Q-> ne [ 0 ] <= 128 || ggml_cuda_info (). devices [ctx. device ]. cc >= CC_AMPERE) ) {
587587 cols_per_block = 32 ;
588588 } else if (Q->ne [1 ] >= 32 || Q->ne [0 ] % 32 != 0 ) {
589589 cols_per_block = 16 ;
0 commit comments