Commit 4c7acaf: Fixes to compile
1 parent 07f588d
3 files changed: +49 -31 lines

include/llama.h

Lines changed: 4 additions & 1 deletion

@@ -267,7 +267,7 @@ extern "C" {
         int8_t * logits; // TODO: rename this to "output"

         struct ggml_tensor * embd_tensor;
-        struct ggml_tensor * cross_embd_tensor;
+        struct ggml_tensor * cross_embd;
     } llama_batch;

     enum llama_model_kv_override_type {
@@ -544,6 +544,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

+    // Returns true if the model has an image attention KV cache
+    LLAMA_API bool llama_model_has_cross_kv(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
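
A minimal usage sketch for the new accessor (not part of this commit), assuming the standard llama.h loading entry points; only llama_model_has_cross_kv comes from this change, everything else is the existing public API:

// Sketch: load a model and ask whether it carries a cross-attention (image) KV cache.
// Assumes the usual llama.h entry points; error handling kept minimal.
#include "llama.h"
#include <stdio.h>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        llama_backend_free();
        return 1;
    }

    // new in this commit: query the cross (image) KV cache capability
    if (llama_model_has_cross_kv(model)) {
        printf("model uses a cross-attention (image) KV cache\n");
    } else {
        printf("model does not use a cross-attention KV cache\n");
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}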

src/llama-model.cpp

Lines changed: 30 additions & 24 deletions

@@ -1249,8 +1249,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

             switch (hparams.n_layer) {
-                case 32: model.type = e_model::MODEL_7B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
+                case 32: type = LLM_TYPE_7B; break;
+                default: type = LLM_TYPE_UNKNOWN;
             }
         } break;
     case LLM_ARCH_WAVTOKENIZER_DEC:
@@ -3384,42 +3384,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COGVLM:
             {
-                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

-                model.output_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);

-                model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

                 // Not supporting ctx_split
                 for (int i=0; i < n_layer; i++) {
-                    ggml_context * ctx_layer = ctx_for_layer(i);
-
-                    auto & layer = model.layers[i];
+                    auto & layer = layers[i];

-                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

-                    layer.wqkv_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_TXT_QKV, "weight", i), {n_embd, n_embd * 3});
-                    layer.wqkv_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_IMG_QKV, "weight", i), {n_embd, n_embd * 3});
-                    layer.wdense_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_TXT_DENSE, "weight", i), {n_embd, n_embd});
-                    layer.wdense_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_IMG_DENSE, "weight", i), {n_embd, n_embd});
+                    layer.wqkv_txt = create_tensor(tn(LLM_TENSOR_ATTN_TXT_QKV, "weight", i), {n_embd, n_embd * 3}, 0);
+                    layer.wqkv_img = create_tensor(tn(LLM_TENSOR_ATTN_IMG_QKV, "weight", i), {n_embd, n_embd * 3}, 0);
+                    layer.wdense_txt = create_tensor(tn(LLM_TENSOR_ATTN_TXT_DENSE, "weight", i), {n_embd, n_embd}, 0);
+                    layer.wdense_img = create_tensor(tn(LLM_TENSOR_ATTN_IMG_DENSE, "weight", i), {n_embd, n_embd}, 0);

-                    layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+                    layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);

-                    layer.wq_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_cross});
+                    layer.wq_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_cross}, 0);
                     // The input dimension is the number of dimensions from the cross vision encoder
                     // it might not be guaranteed that this is the same as the number of dimensions
                     // in the cogvlm attention calculation
-                    layer.wkv_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_KV, "weight", i), {n_embd_cross, n_embd_cross * 2});
-                    layer.wdense_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_DENSE, "weight", i), {n_embd_cross, n_embd});
+                    layer.wkv_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_KV, "weight", i), {n_embd_cross, n_embd_cross * 2}, 0);
+                    layer.wdense_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_DENSE, "weight", i), {n_embd_cross, n_embd}, 0);

-                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

-                    layer.ffn_gate_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_GATE, "weight", i), {n_embd, n_ff});
-                    layer.ffn_down_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_DOWN, "weight", i), {n_ff, n_embd});
-                    layer.ffn_up_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_UP, "weight", i), {n_embd, n_ff});
-                    layer.ffn_gate_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_GATE, "weight", i), {n_embd, n_ff});
-                    layer.ffn_down_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_DOWN, "weight", i), {n_ff, n_embd});
-                    layer.ffn_up_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_UP, "weight", i), {n_embd, n_ff});
+                    layer.ffn_gate_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    layer.ffn_up_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_UP, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_gate_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    layer.ffn_up_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_UP, "weight", i), {n_embd, n_ff}, 0);
                }
            } break;
        case LLM_ARCH_WAVTOKENIZER_DEC:
@@ -4170,6 +4168,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
        case LLM_ARCH_GRANITE:
        case LLM_ARCH_GRANITE_MOE:
        case LLM_ARCH_CHAMELEON:
+        case LLM_ARCH_COGVLM:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2
@@ -4309,3 +4308,10 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
        default: return false;
    }
}
+
+bool llama_model_has_cross_kv(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_COGVLM: return true;
+        default: return false;
+    }
+}

src/llama.cpp

Lines changed: 15 additions & 6 deletions

@@ -713,14 +713,14 @@ static struct ggml_tensor * llm_build_kv(
 // cross attention KV cache
 static struct ggml_tensor * llm_build_cross_kv(
         struct ggml_context * ctx,
-        struct llama_context * lctx,
+        struct llama_context & lctx,
         struct ggml_tensor * qcur,
         struct ggml_tensor * kcur,
         struct ggml_tensor * vcur,
         struct ggml_cgraph * graph,
         int64_t il
        ) {
-    llama_cross_kv_cache & kv = lctx->kv_cross;
+    llama_cross_kv_cache & kv = lctx.kv_cross;

    // Q has dimensions K, H, L, B
    // K = hidden dimension per head
@@ -8187,8 +8187,8 @@ struct llm_build_context {

    // Multiplied directly to Q
    const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
-    const float cross_attn_scale = 1.0f / sqrtf(float(hparams.n_embd_cross / hparams.n_head()));

+    struct ggml_tensor * cur;
    struct ggml_tensor * inpL;
    inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

@@ -9495,7 +9495,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) {
    uint32_t n_seqs = 1; // TODO: worst-case number of sequences
    uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
    llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

    // initialize scheduler with the worst-case graph
@@ -9963,6 +9963,15 @@ struct llama_context * llama_init_from_model(
        return nullptr;
    }

+    if (llama_model_has_cross_kv(model)) {
+        // TODO: Add parameter for cross kv cache size
+        if (!llama_cross_kv_cache_init(ctx->kv_cross, ctx->model, type_k, type_v, 1024 * 6400, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_cross_kv_cache_init() failed\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+    }
+
    {
        size_t memory_size_k = 0;
        size_t memory_size_v = 0;
@@ -10058,7 +10067,7 @@ struct llama_context * llama_init_from_model(
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
        llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph

-        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
        ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);

        // reserve pp graph first so that buffers are only allocated once
@@ -10067,7 +10076,7 @@ struct llama_context * llama_init_from_model(
        int n_nodes_pp = ggml_graph_n_nodes(gf_pp);

        // reserve with tg graph to get the number of splits and nodes
-        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
        ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
        ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
        int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
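
For context, a hypothetical call-site sketch (not part of this commit) of how the pieces above could fit together: it assumes the decode path consumes the renamed llama_batch.cross_embd field when llama_model_has_cross_kv() reports true, and that the image embedding tensor is produced elsewhere by a vision encoder.

// Hypothetical sketch: attach a precomputed image embedding to a batch before decoding.
// The cross_embd field and llama_model_has_cross_kv() come from this commit; how the
// decode path consumes cross_embd is an assumption here, not something this diff shows.
#include "llama.h"

static int decode_with_image(struct llama_context * ctx,
                             const struct llama_model * model,
                             llama_token * tokens, int32_t n_tokens,
                             struct ggml_tensor * image_embd) { // from a vision encoder (hypothetical)
    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);

    if (llama_model_has_cross_kv(model)) {
        batch.cross_embd = image_embd; // field renamed from cross_embd_tensor in this commit
    }

    return llama_decode(ctx, batch); // 0 on success
}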
