2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -3148,6 +3148,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
// FIXME: this is not accurate, the flash attn implementation only has kernels for a limited number of configurations,
// which varies depending on too many factors to duplicate here.
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif
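The workaround for this inaccuracy on the consumer side is to probe the device with the exact op that will be run, which is what the llama.cpp changes below do through their buft_supported helper. Below is a minimal, self-contained sketch of that pattern, assuming the ggml backend-device API (ggml_backend_dev_supports_op); it is illustrative code, not part of the PR, and it omits the KQ mask that the real probe builds.

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// illustrative sketch, not PR code: ask a backend device whether it supports one concrete
// flash-attention configuration instead of trusting the compile-time FLASH_ATTN_AVAILABLE gate
static bool device_supports_flash_attn(ggml_backend_dev_t dev,
        int64_t n_embd_head, int64_t n_head, int64_t n_head_kv,
        int64_t n_kv, int64_t n_tokens, ggml_type type_kv) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // metadata only, nothing is allocated or executed
    };
    ggml_context * ctx = ggml_init(params);

    // Q is F32, K/V use the KV cache type; shapes follow the ggml_flash_attn_ext convention
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);
    ggml_tensor * k = ggml_new_tensor_3d(ctx, type_kv,       n_embd_head, n_kv,     n_head_kv);
    ggml_tensor * v = ggml_new_tensor_3d(ctx, type_kv,       n_embd_head, n_kv,     n_head_kv);

    // mask omitted to keep the sketch short; the probe in src/llama.cpp below builds one
    ggml_tensor * op = ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL,
            /*scale =*/ 1.0f, /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    const bool supported = ggml_backend_dev_supports_op(dev, op);

    ggml_free(ctx);
    return supported;
}
```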
158 changes: 105 additions & 53 deletions src/llama.cpp
@@ -2777,7 +2777,7 @@ struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
std::vector<bool> v_trans_l; // the value tensor is transposed (per layer)

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@@ -3395,7 +3395,7 @@ static int llama_get_device_count(const llama_model & model) {
}

template<typename F>
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, const F & fn) {
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
@@ -3446,16 +3446,12 @@ static bool llama_kv_cache_init(
uint32_t kv_size,
bool offload) {
const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams;

const struct llama_hparams & hparams = model.hparams;

const int64_t n_layer = hparams.n_layer;
const int64_t n_layer = hparams.n_layer;

cache.has_shift = false;

cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;

cache.head = 0;
cache.size = kv_size;
@@ -9699,10 +9695,7 @@ static struct ggml_tensor * llm_build_kqv(

struct ggml_tensor * cur;

if (cparams.flash_attn) {
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);

if (kv.v_trans_l[il]) {
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@@ -19400,6 +19393,10 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}

if (llama_model_is_recurrent(model)) {
params.flash_attn = false;
}

if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
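One caller-visible consequence of this ordering, sketched below with hypothetical caller code (not part of the PR): for a recurrent model, flash_attn is forced off before the quantized-V check runs, so requesting a quantized V cache now fails context creation up front.

```cpp
// hypothetical caller-side sketch; `model` is a previously loaded recurrent model (e.g. Mamba)
llama_context_params cparams = llama_context_default_params();
cparams.flash_attn = true;              // forced off for recurrent models by the hunk above
cparams.type_v     = GGML_TYPE_Q8_0;    // quantized V cache requires flash_attn

// with flash_attn cleared, this logs "V cache quantization requires flash_attn"
// and returns nullptr instead of creating the context
llama_context * ctx = llama_new_context_with_model(model, cparams);
```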
@@ -19495,23 +19492,98 @@ struct llama_context * llama_new_context_with_model(
// build worst-case graph for encoder if a model contains encoder
ctx->is_encoding = llama_model_has_encoder(model);

uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
if (!hparams.vocab_only) {
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;

// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}

GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);

if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}

// find which layers can use flash attention
std::vector<bool> & flash_attn_layers = ctx->kv_self.v_trans_l;
flash_attn_layers.resize(hparams.n_layer, false);
if (cparams.flash_attn) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
ggml_backend_dev_t layer_dev = model->dev_layer.at(il).dev;
ggml_backend_buffer_type_t layer_buft = ggml_backend_dev_buffer_type(layer_dev);

auto & kv = ctx->kv_self;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int64_t n_embd_head = hparams.n_embd_head_v;
int n_kv = 128;
int n_tokens = 128;

bool supported = buft_supported(layer_buft, layer_dev, [&](ggml_context * ctx) -> ggml_tensor * {
ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
float kq_scale = 1.0f;
ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);

// split cached v into n_head heads (not transposed)
ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);

ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
return cur;
});

LLAMA_LOG_INFO("%s: layer %2d %s flash_attn %s\n", __func__, il, ggml_backend_dev_name(layer_dev), supported ? "supported" : "not supported");

flash_attn_layers[il] = supported;
}
}

{
size_t memory_size_k = 0;
size_t memory_size_v = 0;

for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}

for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}

LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}

if (!hparams.vocab_only) {
// GPU backends
for (auto * dev : model->devices) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
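For reference, the arithmetic behind the KV-size log line added above, with hypothetical numbers (not taken from the PR): each layer stores kv_size rows of n_embd_k_gqa key elements and n_embd_v_gqa value elements in the chosen cache types, and the log sums ggml_nbytes over all layers, which for unquantized types works out as below.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical model: 32 layers, 8192-cell cache, 1024-wide K and V rows per layer, F16 cache types
    const int64_t n_layer      = 32;
    const int64_t kv_size      = 8192;
    const int64_t n_embd_k_gqa = 1024;
    const int64_t n_embd_v_gqa = 1024;
    const int64_t type_size    = 2; // bytes per element for GGML_TYPE_F16

    const int64_t memory_size_k = n_layer * kv_size * n_embd_k_gqa * type_size; // 536870912 bytes = 512 MiB
    const int64_t memory_size_v = n_layer * kv_size * n_embd_v_gqa * type_size; // 536870912 bytes = 512 MiB

    // same shape as the log line above: total, then K and V separately
    printf("KV self size = %7.2f MiB, K: %7.2f MiB, V: %7.2f MiB\n",
        (memory_size_k + memory_size_v) / (1024.0 * 1024.0),
        memory_size_k / (1024.0 * 1024.0),
        memory_size_v / (1024.0 * 1024.0));
    return 0;
}
```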
@@ -19558,30 +19630,6 @@ struct llama_context * llama_new_context_with_model(
}
}

if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}

{
size_t memory_size_k = 0;
size_t memory_size_v = 0;

for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}

for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}

LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}

// graph outputs buffer
{
// resized during inference when a batch uses more outputs
Expand Down Expand Up @@ -20330,7 +20378,8 @@ struct llama_data_write {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;

const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
// FIXME
const uint32_t v_trans = kv_self.v_trans_l.at(0) ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;

write(&v_trans, sizeof(v_trans));
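The FIXME above collapses the per-layer flags to layer 0 when writing the session file. A possible follow-up, sketched below, would be to persist one flag per layer; this is not part of the PR and would change the state-file format.

```cpp
// hypothetical follow-up sketch: persist the whole v_trans_l vector so caches with
// mixed per-layer layouts round-trip correctly through save/load
for (uint32_t il = 0; il < n_layer; ++il) {
    const uint32_t v_trans_il = kv_self.v_trans_l.at(il) ? 1 : 0;
    write(&v_trans_il, sizeof(v_trans_il));
}
```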
@@ -20359,7 +20408,8 @@ struct llama_data_write {
}
}

if (!kv_self.v_trans) {
// FIXME
if (!kv_self.v_trans_l.at(0)) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -20652,7 +20702,8 @@ struct llama_data_read {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
return false;
}
if (kv_self.v_trans != (bool) v_trans) {
// FIXME
if (kv_self.v_trans_l.at(0) != (bool) v_trans) {
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
return false;
}
@@ -20685,7 +20736,8 @@ struct llama_data_read {
}
}

if (!kv_self.v_trans) {
// FIXME
if (!kv_self.v_trans_l.at(0)) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

Expand Down
Loading