
Commit 81da919

JohannesGaessler authored and ggerganov committed
no vec for hs, no hs==256 ncols==32 for Volta
1 parent: d59ac67

File tree

2 files changed: +37, -36 lines


ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@
 #define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
+#define CC_AMPERE 800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
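For orientation, a minimal sketch (not part of this commit) of how these CC_* constants are typically compared against a device's compute capability; the helper name is hypothetical, and the cc value is assumed to be the integer stored per device by the backend (as in ggml_cuda_info().devices[id].cc):

// Hypothetical helper: true for NVIDIA devices of at least the Ampere
// generation. AMD devices are offset by CC_OFFSET_AMD, so they never fall
// into the NVIDIA comparison range by accident.
static bool cc_is_nvidia_ampere_or_newer(const int cc) {
    return cc < CC_OFFSET_AMD && cc >= CC_AMPERE;
}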

ggml-cuda/fattn.cu

Lines changed: 36 additions & 36 deletions
@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
         const int nwarps = Q->ne[0] / WARP_SIZE;
         const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
         switch (Q->ne[0]) {
-            case 64:
-                flash_attn_vec_ext_f16<64>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 64:
+            //     flash_attn_vec_ext_f16<64>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 80:
             //     flash_attn_vec_ext_f16<80>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
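In effect, this first hunk tightens the dispatch predicate for the single-column vector kernel. A hedged restatement as a hypothetical helper (assuming WARP_SIZE is 32, as defined elsewhere in ggml-cuda):

// Restated predicate after this commit: the "vec" kernel is used only for
// batch size 1 (Q->ne[1] == 1) and head sizes that are both a multiple of
// WARP_SIZE and at least 128; smaller head sizes ("no vec for hs" in the
// commit title) fall through to the general kernel instead.
static bool use_flash_attn_vec(const ggml_tensor * Q) {
    return Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1;
}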
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
             //         );
             //     break;
-            case 96:
-                flash_attn_vec_ext_f16<96>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 96:
+            //     flash_attn_vec_ext_f16<96>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 112:
             //     flash_attn_vec_ext_f16<112>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[0] % 32 == 0) {
         if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
             cols_per_block = 64;
-        } else if (Q->ne[1] >= 64) {
+        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
             cols_per_block = 32;
         } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
             cols_per_block = 16;
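This last hunk gates cols_per_block == 32 for large head sizes behind Ampere. A hedged restatement as a standalone selection (the helper and its parameter names are hypothetical; cc is the device compute capability, and only the branches visible in this hunk are reflected):

// Hypothetical restatement of the selection after this commit: head sizes
// above 128 get 32 columns only on Ampere or newer ("no hs==256 ncols==32
// for Volta" in the commit title); otherwise the narrower variant is used.
static int pick_cols_per_block(const int64_t head_size, const int64_t n_q, const int cc) {
    if (n_q >= 128 && head_size <= 128) {
        return 64;
    }
    if (n_q >= 64 && (head_size <= 128 || cc >= CC_AMPERE)) {
        return 32;
    }
    return 16; // the hunk is truncated here; further branches may follow
}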
