diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 9f6a4fef6..6a5a16f0a 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
+    int kv_pad = 0;
+    //if (flash_attn) {
+    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    //}
     // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (
+        d_head == 64 ||
+        d_head == 80 ||
+        d_head == 96 ||
+        d_head == 112 ||
+        d_head == 128 ||
+        d_head == 256
+    );
+#if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
-
-    // cuda max d_head seems to be 256, cpu does seem to work with 512
-    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+#else
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+#endif
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        //LOG_DEBUG(" uses flash attention");
+        if (kv_pad != 0) {
+            //LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
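
Note on the padding arithmetic (editorial sketch, not part of the patch): the #else branch above only enables padding for the three hard-coded L_k values, and the pad amount comes from GGML_PAD, which (as defined in ggml.h) rounds its argument up to the next multiple of the given alignment. A minimal standalone check of what kv_pad works out to for those lengths, assuming only that macro definition:

    // Standalone sketch: reproduces the kv_pad computation from the patch above.
    // GGML_PAD(x, n) rounds x up to the next multiple of n (definition taken from ggml.h),
    // so the pad amount is GGML_PAD(L_k, 256) - L_k.
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const int kv_lens[] = {77, 4208, 3952};  // the L_k values special-cased in the patch
        for (int L_k : kv_lens) {
            int kv_pad = GGML_PAD(L_k, 256) - L_k;
            // 77 -> 256 (pad 179), 4208 -> 4352 (pad 144), 3952 -> 4096 (pad 144)
            printf("L_k=%d kv_pad=%d padded=%d\n", L_k, kv_pad, L_k + kv_pad);
        }
        return 0;
    }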