
Commit 20e429d

add flash_attn

1 parent 6cccad2

File tree: 4 files changed, +63 / -9 lines


ggml/src/ggml-cuda/fattn.cu
Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
         case 128:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
+        case 192:
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
+            break;
         case 256:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
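
Head sizes in this MMA FlashAttention path are compile-time template parameters, so supporting 192 means adding another case to the runtime-to-compile-time dispatch switch. Below is a minimal standalone sketch of that pattern; the stub function and the exact set of cases are illustrative assumptions, not taken from the file.

#include <cstdio>

// Stand-in for the templated kernel launcher; in the real code the template
// parameter selects a FlashAttention kernel specialized for that head size.
template <int head_size>
static void launch_attn_stub(int n_tokens) {
    std::printf("launching head_size=%d kernel for %d tokens\n", head_size, n_tokens);
}

// Map a runtime head size onto a compile-time template parameter.
static void dispatch_by_head_size(int head_size, int n_tokens) {
    switch (head_size) {
        case  64: launch_attn_stub< 64>(n_tokens); break;
        case 128: launch_attn_stub<128>(n_tokens); break;
        case 192: launch_attn_stub<192>(n_tokens); break; // newly supported size
        case 256: launch_attn_stub<256>(n_tokens); break;
        default:  std::printf("unsupported head size: %d\n", head_size); break;
    }
}

int main() {
    dispatch_by_head_size(192, 8);
    return 0;
}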

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 1 addition & 1 deletion

@@ -3248,7 +3248,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
 }
 
 static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 9999999;
+    const int min_batch_size = 32;
 
     return get_op_batch_size(op) >= min_batch_size;

ggml/src/ggml-cuda/pad.cu
Lines changed: 48 additions & 5 deletions

@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
     }
 }
 
+static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
 static void pad_f32_cuda(const float * x, float * dst,
     const int ne00, const int ne01, const int ne02, const int ne03,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@@ -33,17 +58,35 @@ static void pad_f32_cuda(const float * x, float * dst,
     pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
 }
 
+static void pad_f16_cuda(const half * x, half * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    if (src0->type == GGML_TYPE_F32) {
+        const float * src0_d = (const float *)src0->data;
+        float * dst_d = (float *)dst->data;
+        pad_f32_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    } else {
+        const half * src0_d = (const half *)src0->data;
+        half * dst_d = (half *)dst->data;
+        pad_f16_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    }
 }
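
As a reading aid, here is a minimal CPU reference (plain C++, not part of the commit) of the index mapping the pad_f16 kernel uses: a destination element is copied from the source when its coordinates lie within the source extents, and set to zero otherwise, using the same flattened offsets as the kernel.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative CPU reference for the zero-padding index mapping above.
// Assumed semantics: dst is (ne0, ne1, ne2), src is (ne00, ne01, ne02), both contiguous.
static void pad_3d_ref(const std::vector<float> & src, std::vector<float> & dst,
                       int64_t ne00, int64_t ne01, int64_t ne02,
                       int64_t ne0,  int64_t ne1,  int64_t ne2) {
    dst.assign(ne0*ne1*ne2, 0.0f);
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                const int64_t offset_dst = i0 + i1*ne0 + i2*ne0*ne1;
                if (i0 < ne00 && i1 < ne01 && i2 < ne02) {
                    const int64_t offset_src = i0 + i1*ne00 + i2*ne00*ne01;
                    dst[offset_dst] = src[offset_src];
                } // else: already zero, matching the kernel's else branch
            }
        }
    }
}

int main() {
    std::vector<float> src = {1, 2, 3, 4, 5, 6};   // 3 x 2 x 1 source
    std::vector<float> dst;
    pad_3d_ref(src, dst, 3, 2, 1, 5, 2, 1);        // pad dim 0 from 3 to 5
    for (float v : dst) std::printf("%g ", v);     // prints: 1 2 3 0 0 4 5 6 0 0
    std::printf("\n");
    return 0;
}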

src/llama.cpp
Lines changed: 11 additions & 3 deletions

@@ -588,8 +588,16 @@ static struct ggml_tensor * llm_build_kqv(
                     ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
                     0);
         cb(v, "v", il);
+
+        struct ggml_tensor * padded_v = v;
+        int64_t n_embd_head_v_out = n_embd_head_v;
+        if (n_embd_head_v < n_embd_head_k) {
+            padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
+            cb(padded_v, "padded_v", il);
+            n_embd_head_v_out = n_embd_head_k;
+        }
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+        cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -9567,8 +9575,8 @@ struct llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+    if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
     }
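
The llm_build_kqv change zero-pads V with ggml_pad before handing it to ggml_flash_attn_ext whenever the V head size is smaller than the K head size. Below is a minimal, standalone sketch of that zero-padding semantics, assuming head sizes of 192 (K) and 128 (V) and small illustrative n_kv/n_head_kv values; the shapes, sizes, and the padded dimension here are assumptions for illustration, not copied from the commit.

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,   // shapes only, no tensor data
    };
    struct ggml_context * ctx = ggml_init(ip);

    // assumed sizes for illustration
    const int64_t n_embd_head_k = 192, n_embd_head_v = 128, n_kv = 32, n_head_kv = 8;
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_head_k, n_kv, n_head_kv);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_head_v, n_kv, n_head_kv);

    // zero-pad the head dimension of V up to K's head size (by 64 here),
    // so FlashAttention sees equal K/V head sizes
    struct ggml_tensor * padded_v = ggml_pad(ctx, v, (int)(k->ne[0] - v->ne[0]), 0, 0, 0);

    std::printf("padded_v: %lld x %lld x %lld\n",
            (long long) padded_v->ne[0], (long long) padded_v->ne[1], (long long) padded_v->ne[2]);

    ggml_free(ctx);
    return 0;
}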
