
Commit c4cf462

Saving version that runs just like before rebase

1 parent 4c7acaf

5 files changed: 31 additions & 9 deletions

common/arg.cpp

Lines changed: 2 additions & 2 deletions
@@ -1406,14 +1406,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.mmproj = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_COGAGENT}));
     add_opt(common_arg(
         {"--image"}, "FILE",
         "path to an image file. use with multimodal models. Specify multiple times for batching",
         [](common_params & params, const std::string & value) {
             params.image.emplace_back(value);
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION}));
+    ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION, LLAMA_EXAMPLE_COGAGENT}));
     if (llama_supports_rpc()) {
         add_opt(common_arg(
             {"--rpc"}, "SERVERS",

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_VISION,
+    LLAMA_EXAMPLE_COGAGENT,
 
     LLAMA_EXAMPLE_COUNT,
 };

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ if (EMSCRIPTEN)
 else()
     add_subdirectory(batched-bench)
     add_subdirectory(batched)
+    add_subdirectory(cogagent)
     add_subdirectory(embedding)
     add_subdirectory(eval-callback)

src/llama-arch.cpp

Lines changed: 13 additions & 0 deletions
@@ -1586,6 +1586,19 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_PW1,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONVNEXT_PW2,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONVNEXT_GAMMA,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ATTN_TXT_QKV,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_IMG_QKV,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_TXT_DENSE,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_IMG_DENSE,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_Q,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_KV,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_DENSE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_TXT_UP,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_TXT_GATE,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_TXT_DOWN,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_IMG_UP,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_IMG_GATE,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_IMG_DOWN,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     // vision
     {LLM_TENSOR_V_MMPROJ,         {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_V_MMPROJ_MLP,     {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
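
For context: each LLM_TENSOR_INFOS entry records where a tensor lives (LLM_TENSOR_LAYER_REPEATING means one copy per transformer layer) and the ggml op it feeds, which the model loader consults when picking a buffer type for that weight. A minimal lookup sketch, assuming the llm_tensor_info_for() helper defined alongside this table:

    const llm_tensor_info & info = llm_tensor_info_for(LLM_TENSOR_CROSS_ATTN_KV);
    // info.layer == LLM_TENSOR_LAYER_REPEATING -> allocated once per layer
    // info.op    == GGML_OP_MUL_MAT            -> the weight is used in a matmul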

src/llama.cpp

Lines changed: 14 additions & 7 deletions
@@ -158,9 +158,17 @@ static struct ggml_tensor * llm_build_inp_embd(
 }
 
 static struct ggml_tensor * llm_build_cross_embd(
+        struct ggml_context * ctx,
         const llama_ubatch & ubatch
 ) {
-    struct ggml_tensor * cross_embd = ubatch.cross_embd;
+    struct ggml_tensor * cross_embd;
+    if (ubatch.cross_embd) {
+        cross_embd = ubatch.cross_embd;
+    } else {
+        printf("ubatch does not have cross_embd tensor, "
+               "building graph with placeholder instead\n");
+        cross_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 6400);
+    }
     ggml_set_input(cross_embd);
     return cross_embd;
 }
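
The fallback keeps the graph topology and tensor shapes fixed when a batch carries no image embedding (for example during an initial graph build used to reserve memory). The hard-coded 1024 x 6400 placeholder presumably mirrors CogAgent's cross-vision encoder output: a 1120 x 1120 image at patch size 14 yields 80 x 80 = 6400 patch tokens, each 1024-dimensional. Those dimensions are an inference from the model, not something this diff states.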
@@ -727,7 +735,7 @@ static struct ggml_tensor * llm_build_cross_kv(
     // H = number of heads
     // L = number of tokens
     // B = batch size
-    const int64_t num_heads = qcur->ne[1];
+    const int64_t num_heads = lctx.model.hparams.n_head();
     const float cross_attn_scale = 1.0f / sqrtf(float(qcur->ne[0]));
     // Only add the computation of K and V if
     // the cache doesn't already have the data
@@ -744,10 +752,8 @@ static struct ggml_tensor * llm_build_cross_kv(
     // Compute cross attention score
     struct ggml_tensor * q = ggml_reshape_4d(ctx, qcur, qcur->ne[0] / num_heads,
                                              num_heads, qcur->ne[1], qcur->ne[2]);
-    k = ggml_reshape_4d(ctx, k, kcur->ne[0] / num_heads, num_heads,
-                        kcur->ne[1], kcur->ne[2]);
-    v = ggml_reshape_4d(ctx, v, vcur->ne[0] / num_heads, num_heads,
-                        vcur->ne[1], vcur->ne[2]);
+    k = ggml_reshape_3d(ctx, k, 1024 / num_heads, num_heads, 6400);
+    v = ggml_reshape_3d(ctx, v, 1024 / num_heads, num_heads, 6400);
     q = ggml_permute(ctx, q, 0, 2, 1, 3);
     k = ggml_permute(ctx, k, 0, 2, 1, 3);
     v = ggml_permute(ctx, v, 1, 2, 0, 3);
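
Shape walk-through of the rewritten reshapes (a sketch; H = hparams.n_head(), and 1024/6400 are the same hard-coded cross-embedding width and image-token count as in the placeholder above):

    // k, v arrive flattened as [1024, 6400] = [width, n_image_tokens]
    // k: reshape_3d -> [1024/H, H, 6400]; permute(0,2,1,3) -> [1024/H, 6400, H]
    // v: reshape_3d -> [1024/H, H, 6400]; permute(1,2,0,3) -> [6400, 1024/H, H]
    // q: reshape_4d -> [q_dim/H, H, n_tok, B]; permute(0,2,1,3) -> [q_dim/H, n_tok, H, B]

The head-count change matters because qcur->ne[1] holds the token count after the Q projection, not the number of heads, so the old derivation was presumably only correct when the two happened to coincide; hparams.n_head() is the reliable source.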
@@ -8194,7 +8200,7 @@ struct llm_build_context {
 
     // Get the cross vision encoder embedded picture
     struct ggml_tensor * cross_embd;
-    cross_embd = llm_build_cross_embd(ubatch);
+    cross_embd = llm_build_cross_embd(ctx0, ubatch);
 
     // Assuming text tokens are in ubatch.token, and image tokens are in ubatch.embd_tensor
     bool batch_is_text;
@@ -8310,6 +8316,7 @@ struct llm_build_context {
 
             inpSA = ggml_add(ctx0, inpSA, cur);
         }
+        lctx.kv_cross.cache_filled = true;
 
         cur = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps);
         cur = ggml_mul(ctx0, cur, model.output_norm);
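
Setting cache_filled after the layer loop pairs with the guard from the earlier hunk ("Only add the computation of K and V if the cache doesn't already have the data"): later graph builds can skip the cross K/V projection and reuse what is already in lctx.kv_cross. A sketch of the presumed guard on the consumer side (not the verbatim source):

    if (!lctx.kv_cross.cache_filled) {
        // project cross_embd to K and V and write them into the cross cache
    }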
