Commit 0c2aad7

add n_embd_full to support extended embed
1 parent: 7fd205a · commit: 0c2aad7

8 files changed: 24 additions, 29 deletions

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -482,6 +482,7 @@ extern "C" {
 
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd_full(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
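
Callers that read embeddings out of the context should note that the output rows are now laid out with the full width (see the llama-context.cpp hunks below). A minimal caller-side sketch, assuming the standard llama_get_embeddings_ith() accessor; the helper function and buffer handling here are illustrative, not part of this commit:

#include <vector>
#include "llama.h"

// Copy output row i, sized by the full (main + auxiliary) embedding width.
std::vector<float> read_embd_row(llama_context * ctx, const llama_model * model, int32_t i) {
    const int32_t n_embd_full = llama_model_n_embd_full(model); // new getter from this commit
    const float * row = llama_get_embeddings_ith(ctx, i);       // strides by n_embd_full after this change
    return std::vector<float>(row, row + n_embd_full);
}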

src/llama-context.cpp

Lines changed: 7 additions & 7 deletions
@@ -620,7 +620,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        return embd + j*model.hparams.n_embd;
+        return embd + j*model.hparams.n_embd_full;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -808,7 +808,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const auto & hparams = model.hparams;
 
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_full;
     const int64_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
@@ -977,7 +977,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_full;
 
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
@@ -1276,7 +1276,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
-    const auto n_embd = hparams.n_embd;
+    const auto n_embd = hparams.n_embd_full;
 
     bool has_logits = true;
     bool has_embd = cparams.embeddings;
@@ -1340,7 +1340,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
 void llama_context::output_reorder() {
     const uint64_t n_vocab = model.vocab.n_tokens();
-    const uint64_t n_embd = model.hparams.n_embd;
+    const uint64_t n_embd = model.hparams.n_embd_full;
 
     for (size_t s = 0; s < output_swaps.size(); ++s) {
         const uint64_t i0 = output_swaps[s].i0;
@@ -1883,7 +1883,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
 
-        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd_full);
 
         io.write(&embd_size, sizeof(embd_size));
 
@@ -2135,7 +2135,7 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_full, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
            return;
        }
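
Every one of these hunks changes the same thing: the per-row stride and total size of the context's output embeddings buffer now use the full width. A minimal sketch of that layout, using only the arithmetic visible in the diff (the helper names are illustrative, not part of the commit):

#include <cstdint>

// Row j of the embeddings output starts at j * n_embd_full, matching
// get_embeddings_ith(): return embd + j*model.hparams.n_embd_full;
inline const float * embd_row(const float * embd, int64_t j, int64_t n_embd_full) {
    return embd + j * n_embd_full;
}

// Total float count mirrors the sizing used in output_reserve() and the
// bound applied in state_write_data().
inline uint64_t embd_buffer_floats(uint64_t n_outputs, uint64_t n_embd_full) {
    return n_outputs * n_embd_full;
}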

src/llama-graph.cpp

Lines changed: 2 additions & 2 deletions
@@ -1142,7 +1142,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_full;
 
     auto inp = std::make_unique<llm_graph_input_embd>();
 
@@ -1279,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     //     return cur;
     //}
 
-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_full;
     const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
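
Since build_inp_embd() now uses the full width, any embedding input fed to the text graph is expected to carry n_embd_full values per token. A rough sketch of the resulting tensor shape, modeled on the ggml_new_tensor_2d call in the hunk above (the helper itself is illustrative):

// Illustrative only: an F32 embedding input of shape [n_embd_full, n_tokens],
// i.e. main + auxiliary (deepstack) features stacked along the feature dim.
ggml_tensor * make_embd_input(ggml_context * ctx, int64_t n_embd_full, int64_t n_tokens) {
    return ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_full, n_tokens);
}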

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ struct llama_hparams {
 
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
+    uint32_t n_embd_full; // main + auxiliary embeds
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
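
For most architectures the two widths are the same; only the deepstack paths (Qwen3VL / Qwen3VL-MoE) widen n_embd_full, as the llama-model.cpp hunks below show. A small sketch of the relationship, with the multiplier taken from those hunks (the helper itself is illustrative):

// load_hparams() first sets n_embd_full = n_embd, then for Qwen3VL-style
// models multiplies it by (n_deepstack_layers + 1).
uint32_t n_embd_full_for(uint32_t n_embd, uint32_t n_deepstack_layers, bool has_deepstack) {
    return has_deepstack ? n_embd * (n_deepstack_layers + 1) : n_embd;
}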

src/llama-model.cpp

Lines changed: 10 additions & 13 deletions
@@ -276,7 +276,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_IM2COL:
             {
-                const int n_embd = hparams.n_embd;
+                const int n_embd = hparams.n_embd_full;
                 ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
                 op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
             } break;
@@ -505,6 +505,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
     ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
     ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
+    hparams.n_embd_full = hparams.n_embd;
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -1041,7 +1042,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
                 // since vision model stacks deepstack features along feature dim
                 // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
-                hparams.n_embd *= hparams.n_deepstack_layers + 1;
+                hparams.n_embd_full *= hparams.n_deepstack_layers + 1;
             } break;
         case LLM_ARCH_QWEN3MOE:
             {
@@ -1067,7 +1068,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
                 // since vision model stacks deepstack features along feature dim
                 // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
-                hparams.n_embd *= hparams.n_deepstack_layers + 1;
+                hparams.n_embd_full *= hparams.n_deepstack_layers + 1;
             } break;
         case LLM_ARCH_PHI2:
             {
@@ -3332,10 +3333,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3VL:
             {
-                // for model loading, the weights only have the main embd
-                // so we need to divide by the number of deepstack layers + 1
-                // n_embd is const int so we declare a new variable
-                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                 // output
@@ -3371,10 +3368,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_QWEN3VLMOE:
             {
-                // for model loading, the weights only have the main embd
-                // so we need to divide by the number of deepstack layers + 1
-                // n_embd is const int so we declare a new variable
-                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                 // output
@@ -6681,8 +6674,8 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
     return ::select_buft(
             *pimpl->dev_layer.at(il).buft_list,
             [&](ggml_context * ctx) {
-                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
-                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_full);
+                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_full);
                 return ggml_add(ctx, cur, layer_dir);
             });
 }
@@ -7329,6 +7322,10 @@ int32_t llama_model_n_embd(const llama_model * model) {
     return model->hparams.n_embd;
 }
 
+int32_t llama_model_n_embd_full(const llama_model * model) {
+    return model->hparams.n_embd_full;
+}
+
 int32_t llama_model_n_layer(const llama_model * model) {
     return model->hparams.n_layer;
 }

src/models/qwen3vl-moe.cpp

Lines changed: 1 addition & 2 deletions
@@ -1,9 +1,8 @@
 #include "models.h"
 
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

src/models/qwen3vl.cpp

Lines changed: 1 addition & 4 deletions
@@ -1,13 +1,10 @@
 #include "models.h"
 
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-
-    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
-
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 

tools/mtmd/mtmd.cpp

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ struct mtmd_context {
     print_timings(ctx_params.print_timings),
     n_threads (ctx_params.n_threads),
     media_marker (ctx_params.media_marker),
-    n_embd_text (llama_model_n_embd(text_model))
+    n_embd_text (llama_model_n_embd_full(text_model))
 {
     if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
         throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
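
mtmd sizes text-side embedding buffers from n_embd_text, so it has to match the full width that the text model's input-embedding path now expects. A hedged sketch of what that sizing looks like; the helper and buffer name are illustrative assumptions, not mtmd API:

#include <cstddef>
#include <vector>

// One row of n_embd_text floats per media token, where n_embd_text is now
// llama_model_n_embd_full(text_model) rather than llama_model_n_embd(text_model).
std::vector<float> alloc_media_embd(size_t n_tokens, size_t n_embd_text) {
    return std::vector<float>(n_tokens * n_embd_text, 0.0f);
}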
