
Commit 07f588d

Added CogVLM
1 parent e884d3d commit 07f588d

12 files changed, +416 -0 lines changed

12 files changed

+416
-0
lines changed

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -267,6 +267,7 @@ extern "C" {
        int8_t * logits; // TODO: rename this to "output"

        struct ggml_tensor * embd_tensor;
+       struct ggml_tensor * cross_embd_tensor;
    } llama_batch;

    enum llama_model_kv_override_type {
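
Note: a minimal sketch of how a caller could hand image embeddings to the decoder through this new field. The tokens buffer, its length and the img_embd tensor (output of a separate vision-encoder step) are assumptions, not part of this commit:

    // sketch only: tokens, n_tokens and img_embd are assumed to exist elsewhere
    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    batch.cross_embd_tensor = img_embd; // picked up as ubatch.cross_embd during decode (see src/llama-batch.cpp)
    llama_decode(ctx, batch);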

src/llama-arch.cpp

Lines changed: 25 additions & 0 deletions
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_GRANITE,          "granite"          },
    { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
    { LLM_ARCH_CHAMELEON,        "chameleon"        },
+   { LLM_ARCH_COGVLM,           "cogvlm"           },
    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
    { LLM_ARCH_VISION_LLAVA,     "llava"            },
    { LLM_ARCH_VISION_MOBILEVLM, "mobilevlm"        },

@@ -1298,6 +1299,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
        },
    },
+   {
+       LLM_ARCH_COGVLM,
+       {
+           { LLM_TENSOR_TOKEN_EMBD,       "embed_tokens" },
+           { LLM_TENSOR_OUTPUT_NORM,      "norm" },
+           { LLM_TENSOR_OUTPUT,           "lm_head" },
+           { LLM_TENSOR_ATTN_NORM,        "layers.%d.input_layernorm" },                           // input_norm_w
+           { LLM_TENSOR_ATTN_TXT_QKV,     "layers.%d.self_attn.language_expert_query_key_value" }, // language_qkv_w
+           { LLM_TENSOR_ATTN_IMG_QKV,     "layers.%d.self_attn.vision_expert_query_key_value" },   // vision_qkv_w
+           { LLM_TENSOR_ATTN_TXT_DENSE,   "layers.%d.self_attn.language_expert_dense" },           // language_dense_w
+           { LLM_TENSOR_ATTN_IMG_DENSE,   "layers.%d.self_attn.vision_expert_dense" },             // vision_dense_w
+           { LLM_TENSOR_ATTN_NORM_2,      "layers.%d.post_cross_attention_layernorm" },            // self_attn_norm_w
+           { LLM_TENSOR_CROSS_ATTN_Q,     "layers.%d.cross_attn.query" },                          // cross_query_w
+           { LLM_TENSOR_CROSS_ATTN_KV,    "layers.%d.cross_attn.key_value" },                      // cross_query_kv
+           { LLM_TENSOR_CROSS_ATTN_DENSE, "layers.%d.cross_attn.dense" },                          // cross_dense_w
+           { LLM_TENSOR_FFN_NORM,         "layers.%d.post_attention_layernorm" },                  // attn_norm_w
+           { LLM_TENSOR_FFN_TXT_UP,       "layers.%d.mlp.language_mlp.up_proj" },                  // language_up_proj_w
+           { LLM_TENSOR_FFN_TXT_GATE,     "layers.%d.mlp.language_mlp.gate_proj" },                // language_gate_proj_w
+           { LLM_TENSOR_FFN_TXT_DOWN,     "layers.%d.mlp.language_mlp.down_proj" },                // language_down_proj_w
+           { LLM_TENSOR_FFN_IMG_UP,       "layers.%d.mlp.vision_mlp.up_proj" },                    // vision_up_proj_w
+           { LLM_TENSOR_FFN_IMG_GATE,     "layers.%d.mlp.vision_mlp.gate_proj" },                  // vision_gate_proj_w
+           { LLM_TENSOR_FFN_IMG_DOWN,     "layers.%d.mlp.vision_mlp.down_proj" }                   // vision_down_proj_w
+       },
+   },
    {
        LLM_ARCH_WAVTOKENIZER_DEC,
        {
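
Note: assuming the tn() formatter behaves for this table as it does for the other architectures (layer index substituted into %d, suffix appended), these entries resolve to checkpoint tensor names such as the following; the two calls are illustrative only:

    // tn(LLM_TENSOR_ATTN_TXT_QKV, "weight", 0) -> "layers.0.self_attn.language_expert_query_key_value.weight"
    // tn(LLM_TENSOR_FFN_IMG_DOWN, "weight", 3) -> "layers.3.mlp.vision_mlp.down_proj.weight"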

src/llama-arch.h

Lines changed: 14 additions & 0 deletions
@@ -66,6 +66,7 @@ enum llm_arch {
    LLM_ARCH_GRANITE_MOE,
    LLM_ARCH_CHAMELEON,
    LLM_ARCH_WAVTOKENIZER_DEC,
+   LLM_ARCH_COGVLM,
    // vision
    LLM_ARCH_VISION_LLAVA,
    LLM_ARCH_VISION_MOBILEVLM,

@@ -354,6 +355,19 @@ enum llm_tensor {
    LLM_TENSOR_POS_NET_ATTN_K,
    LLM_TENSOR_POS_NET_ATTN_V,
    LLM_TENSOR_POS_NET_ATTN_OUT,
+   LLM_TENSOR_ATTN_TXT_QKV,
+   LLM_TENSOR_ATTN_IMG_QKV,
+   LLM_TENSOR_ATTN_TXT_DENSE,
+   LLM_TENSOR_ATTN_IMG_DENSE,
+   LLM_TENSOR_CROSS_ATTN_Q,
+   LLM_TENSOR_CROSS_ATTN_KV,
+   LLM_TENSOR_CROSS_ATTN_DENSE,
+   LLM_TENSOR_FFN_TXT_UP,
+   LLM_TENSOR_FFN_TXT_GATE,
+   LLM_TENSOR_FFN_TXT_DOWN,
+   LLM_TENSOR_FFN_IMG_UP,
+   LLM_TENSOR_FFN_IMG_GATE,
+   LLM_TENSOR_FFN_IMG_DOWN,
    // vision
    LLM_TENSOR_V_MMPROJ,
    LLM_TENSOR_V_MMPROJ_FC,

src/llama-batch.cpp

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
        /*seq_id       =*/ ubatch_seq_id.data(),
        /*output       =*/ ubatch_output.data(),
        /*embd_tensor  =*/ nullptr,
+       /*cross_embd   =*/ nullptr,
    };
    return ubatch;
}

@@ -74,6 +75,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
    } else {
        ubatch.embd = nullptr;
    }
+   if (batch->cross_embd_tensor) {
+       ubatch.cross_embd = batch->cross_embd_tensor;
+   }
    if (ubatch.equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];

@@ -324,6 +328,7 @@ struct llama_batch llama_batch_get_one(
        /*seq_id            =*/ nullptr,
        /*logits            =*/ nullptr,
        /*embd_tensor       =*/ nullptr,
+       /*cross_embd_tensor =*/ nullptr,
    };
}

@@ -337,6 +342,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
        /*seq_id            =*/ nullptr,
        /*logits            =*/ nullptr,
        /*embd_tensor       =*/ nullptr,
+       /*cross_embd_tensor =*/ nullptr,
    };

    if (embd) {

@@ -370,6 +376,7 @@ struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor,
        /*seq_id            =*/ nullptr,
        /*logits            =*/ nullptr,
        /*embd_tensor       =*/ tensor,
+       /*cross_embd_tensor =*/ nullptr,
    };
    batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
    batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);

src/llama-batch.h

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ struct llama_ubatch {
    int8_t * output; // [n_tokens]

    struct ggml_tensor * embd_tensor;
+   struct ggml_tensor * cross_embd;
};

struct llama_sbatch_seq {

src/llama-context.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ struct llama_context {
    struct llama_sbatch sbatch; // TODO: revisit if needed
    struct llama_kv_cache kv_self;
    struct llama_adapter_cvec cvec;
+   struct llama_cross_kv_cache kv_cross;

    std::unordered_map<struct llama_adapter_lora *, float> lora;

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ struct llama_hparams {
    uint32_t n_expert        = 0;
    uint32_t n_expert_used   = 0;
    uint32_t n_rel_attn_bkts = 0;
+   uint32_t n_embd_cross    = 1024; // For cross attention with different hidden size

    // for WavTokenizer
    struct llama_hparams_posnet posnet;
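
Note: the dedicated n_embd_cross hyperparameter exists because the cross-attention weights created in src/llama-model.cpp below mix the language hidden size with the vision encoder's hidden size. The shape flow, taken from the tensor declarations in that file, is roughly:

    // query  : n_embd       -> n_embd_cross       (wq_cross,     {n_embd, n_embd_cross})
    // key/val: n_embd_cross -> 2 * n_embd_cross   (wkv_cross,    {n_embd_cross, n_embd_cross * 2})
    // output : n_embd_cross -> n_embd             (wdense_cross, {n_embd_cross, n_embd})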

src/llama-kv-cache.cpp

Lines changed: 71 additions & 0 deletions
@@ -716,3 +716,74 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
            __func__, kv.used, used_cells);
    }
}
+
+// Cross attention KV cache
+bool llama_cross_kv_cache_init(struct llama_cross_kv_cache & cache,
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        uint32_t n_elements,
+        bool offload) {
+    const struct llama_hparams & hparams = model.hparams;
+    const int32_t n_layer = hparams.n_layer;
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+            ctx_map[buft] = ctx;
+            cache.ctxs.emplace_back(ctx);
+            return ctx;
+        }
+        return it->second;
+    };
+
+    for (int i = 0; i < n_layer; i++) {
+        ggml_backend_buffer_type_t buft;
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+        } else {
+            buft = ggml_backend_cpu_buffer_type();
+        }
+        ggml_context * ctx = ctx_for_buft(buft);
+
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to initialize cross KV cache", __func__);
+            return false;
+        }
+
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_elements);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_elements);
+        ggml_format_name(k, "cross_cache_k_l%d", i);
+        ggml_format_name(v, "cross_cache_v_l%d", i);
+        cache.k_l.push_back(k);
+        cache.v_l.push_back(v);
+    }
+
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            LLAMA_LOG_ERROR("%s: failed to allocate buffer for cross kv cache\n", __func__);
+            return false;
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        LLAMA_LOG_INFO("%s: %10s cross KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        cache.bufs.emplace_back(buf);
+    }
+
+    return true;
+}
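
Note: a sketch of how this initializer might be wired up when a CogVLM context is created. The F16 cache types, the token count and the call site are assumptions; none of this appears in the files of this commit:

    // hypothetical sizing: one flat K and one flat V tensor per layer, each holding
    // n_cross_tokens * n_embd_cross values of image features
    const uint32_t n_cross_tokens = 256; // placeholder, depends on the vision encoder
    const uint32_t n_elements     = n_cross_tokens * model.hparams.n_embd_cross;
    if (!llama_cross_kv_cache_init(ctx->kv_cross, model, GGML_TYPE_F16, GGML_TYPE_F16,
                                   n_elements, /*offload =*/ true)) {
        LLAMA_LOG_ERROR("%s: llama_cross_kv_cache_init() failed\n", __func__);
    }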

src/llama-kv-cache.h

Lines changed: 18 additions & 0 deletions
@@ -216,3 +216,21 @@ struct llama_kv_slot_restorer {
    }
};

+// Simple cache that holds the computed K and V tensors
+// for each layer's cross attention calculation
+struct llama_cross_kv_cache {
+    std::vector<struct ggml_tensor *> k_l;
+    std::vector<struct ggml_tensor *> v_l;
+
+    std::vector<ggml_context_ptr> ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    bool cache_filled;
+};
+
+bool llama_cross_kv_cache_init(struct llama_cross_kv_cache & cache,
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        uint32_t n_elements,
+        bool offload);
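
Note: the cache_filled flag suggests a fill-once / reuse pattern: the cross K and V depend only on the image embeddings, so a consumer would presumably compute them on the first decode that carries cross embeddings and reuse them afterwards. A hypothetical sketch, not taken from this commit:

    if (!kv_cross.cache_filled && ubatch.cross_embd != nullptr) {
        // build the graph nodes that project ubatch.cross_embd through wkv_cross
        // and copy the result into kv_cross.k_l[il] / kv_cross.v_l[il] per layer
        kv_cross.cache_filled = true;
    }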

src/llama-model.cpp

Lines changed: 50 additions & 0 deletions
@@ -1244,6 +1244,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
+       case LLM_ARCH_COGVLM:
+           {
+               ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+               switch (hparams.n_layer) {
+                   case 32: model.type = e_model::MODEL_7B; break;
+                   default: model.type = e_model::MODEL_UNKNOWN;
+               }
+           } break;
        case LLM_ARCH_WAVTOKENIZER_DEC:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

@@ -1443,6 +1452,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
    const int64_t n_expert      = hparams.n_expert;
    const int64_t n_expert_used = hparams.n_expert_used;
    const int64_t n_ctx_train   = hparams.n_ctx_train;
+   const int64_t n_embd_cross  = hparams.n_embd_cross;

    if (n_expert > 0 && hparams.n_expert_used == 0) {
        throw std::runtime_error("model has expert layers but no expert layers are used");

@@ -3372,6 +3382,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                }
            } break;
+       case LLM_ARCH_COGVLM:
+           {
+               model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+               model.output_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+
+               model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+
+               // Not supporting ctx_split
+               for (int i=0; i < n_layer; i++) {
+                   ggml_context * ctx_layer = ctx_for_layer(i);
+
+                   auto & layer = model.layers[i];
+
+                   layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                   layer.wqkv_txt   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_TXT_QKV, "weight", i), {n_embd, n_embd * 3});
+                   layer.wqkv_img   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_IMG_QKV, "weight", i), {n_embd, n_embd * 3});
+                   layer.wdense_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_TXT_DENSE, "weight", i), {n_embd, n_embd});
+                   layer.wdense_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_IMG_DENSE, "weight", i), {n_embd, n_embd});
+
+                   layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+
+                   layer.wq_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_cross});
+                   // The input dimension is the number of dimensions from the cross vision encoder
+                   // it might not be guaranteed that this is the same as the number of dimensions
+                   // in the cogvlm attention calculation
+                   layer.wkv_cross    = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_KV, "weight", i), {n_embd_cross, n_embd_cross * 2});
+                   layer.wdense_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CROSS_ATTN_DENSE, "weight", i), {n_embd_cross, n_embd});
+
+                   layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                   layer.ffn_gate_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_GATE, "weight", i), {n_embd, n_ff});
+                   layer.ffn_down_txt = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_DOWN, "weight", i), {n_ff, n_embd});
+                   layer.ffn_up_txt   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_TXT_UP, "weight", i), {n_embd, n_ff});
+                   layer.ffn_gate_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_GATE, "weight", i), {n_embd, n_ff});
+                   layer.ffn_down_img = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_DOWN, "weight", i), {n_ff, n_embd});
+                   layer.ffn_up_img   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_IMG_UP, "weight", i), {n_embd, n_ff});
+               }
+           } break;
        case LLM_ARCH_WAVTOKENIZER_DEC:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
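
Note: the duplicated *_txt / *_img weights follow CogVLM's visual-expert design, where image positions and text positions in the same layer use separate QKV, dense and MLP weights. A rough ggml sketch of the routing a graph builder would need; ctx0, inp_txt, inp_img and the split/merge of the sequence are assumptions, since the builder itself is not in the files shown here:

    // per-layer sketch only: run each token group through its own expert weights,
    // then stitch the sequence back together along the token dimension
    struct ggml_tensor * qkv_txt = ggml_mul_mat(ctx0, layer.wqkv_txt, inp_txt); // text positions
    struct ggml_tensor * qkv_img = ggml_mul_mat(ctx0, layer.wqkv_img, inp_img); // image positions
    struct ggml_tensor * qkv     = ggml_concat(ctx0, qkv_img, qkv_txt, 1);      // [3*n_embd, n_tokens]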
