39 changes: 31 additions & 8 deletions ggml_extend.hpp
@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

float scale = (1.0f / sqrt((float)d_head));

// if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
// }
int kv_pad = 0;
//if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
//}
// is there anything oddly shaped?? ping Green-Sky if you can trip this assert
GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

bool can_use_flash_attn = true;
can_use_flash_attn = can_use_flash_attn && (
d_head == 64 ||
d_head == 80 ||
d_head == 96 ||
d_head == 112 ||
d_head == 128 ||
d_head == 256
);
#if 0
can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check

// cuda max d_head seems to be 256, cpu does seem to work with 512
can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check
#else
if (can_use_flash_attn && L_k % 256 != 0) {
// TODO(Green-Sky): might be worth just padding by default
if (L_k == 77 || L_k == 4208 || L_k == 3952) {
kv_pad = GGML_PAD(L_k, 256) - L_k;
} else {
can_use_flash_attn = false;
}
}
#endif

if (mask != nullptr) {
// TODO(Green-Sky): figure out if we can bend t5 to work too
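
Note (not part of the diff): the kv_pad computation above rounds L_k up to the next multiple of 256 and pads by the difference. Below is a minimal standalone sketch of that rule; round_up is a hypothetical stand-in for ggml's GGML_PAD macro, which rounds up the same way for these inputs.

// Standalone sketch, not part of the PR: mirrors the kv padding rule above.
#include <cstdio>

// hypothetical stand-in for GGML_PAD(x, n): round x up to the next multiple of n
static int round_up(int x, int n) {
    return (x + n - 1) / n * n;
}

int main() {
    // the L_k values special-cased in the diff; 77 is presumably the CLIP text length
    const int lks[] = {77, 4208, 3952};
    for (int lk : lks) {
        int kv_pad = round_up(lk, 256) - lk;
        std::printf("L_k=%4d -> padded L_k=%4d, kv_pad=%3d\n", lk, round_up(lk, 256), kv_pad);
    }
    // prints: 77 -> 256 (pad 179), 4208 -> 4352 (pad 144), 3952 -> 4096 (pad 144)
    return 0;
}
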
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
ggml_tensor* kqv = nullptr;
// GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
if (can_use_flash_attn && flash_attn) {
// LOG_DEBUG("using flash attention");
//LOG_DEBUG(" uses flash attention");
if (kv_pad != 0) {
//LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
}
k = ggml_cast(ctx, k, GGML_TYPE_F16);

v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
if (kv_pad != 0) {
v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
}
v = ggml_cast(ctx, v, GGML_TYPE_F16);

if (mask != nullptr) {
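
Note (not part of the diff): in the second hunk, the second pad argument (p1) of ggml_pad is what grows the L_k axis, since dim1 of k, and of the permuted/reshaped v, is the sequence length in the [N * n_head, L_k, d_head] layout the surrounding comments describe. A small sketch, assuming a working ggml build and that k has that layout at the point it is padded; it only checks the resulting shapes and never computes the op.

// Standalone sketch, not part of the PR: shape bookkeeping for the padding above.
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};  // mem_size, mem_buffer, no_alloc
    ggml_context* ctx = ggml_init(params);

    // example values; 179 = 256 - 77, as computed for L_k == 77 in the diff
    const int d_head = 64, L_k = 77, n_head_times_N = 8, kv_pad = 179;

    // k laid out as [N * n_head, L_k, d_head] -> ne = {d_head, L_k, n_head * N}
    ggml_tensor* k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, L_k, n_head_times_N);
    ggml_tensor* k_padded = ggml_pad(ctx, k, 0, kv_pad, 0, 0);  // zero-pads dim1 only

    std::printf("k:        [%lld, %lld, %lld]\n",
                (long long)k->ne[0], (long long)k->ne[1], (long long)k->ne[2]);
    std::printf("k_padded: [%lld, %lld, %lld]\n",
                (long long)k_padded->ne[0], (long long)k_padded->ne[1], (long long)k_padded->ne[2]);
    // expected: k [64, 77, 8], k_padded [64, 256, 8]

    ggml_free(ctx);
    return 0;
}
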