Commit 4d1ff87

qwen3 get embedding from logits
1 parent e562eec commit 4d1ff87

3 files changed: +49, -8 lines

src/llama-context.cpp

Lines changed: 34 additions & 5 deletions
@@ -887,7 +887,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int32_t n_vocab = vocab.n_tokens();
 
     const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd = hparams.n_embd;
+    int64_t n_embd = hparams.n_embd;
+
+    if (model.arch == LLM_ARCH_QWEN3 || model.arch == LLM_ARCH_QWEN3MOE) {
+        // Qwen3 uses a different embedding size
+        n_embd = n_vocab;
+    }
 
     llama_kv_cache_guard kv_guard(kv_self);
 
@@ -1021,7 +1026,15 @@ int llama_context::decode(llama_batch & inp_batch) {
 
             if (n_outputs) {
                 GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+
+                if (model.arch == LLM_ARCH_QWEN3 && cparams.embeddings) {
+                    // For Qwen3 with embeddings enabled, we share the tensor between logits and embeddings
+                    GGML_ASSERT(n_outputs * n_vocab <= (int64_t) logits_size);
+                } else {
+                    // Standard check for other model architectures
+                    GGML_ASSERT((n_outputs_prev + n_outputs) * n_vocab <= (int64_t) logits_size);
+                }
+
                 ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
             }
         }
@@ -1170,7 +1183,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
-    const auto n_embd  = hparams.n_embd;
+    int64_t n_embd = hparams.n_embd;
+
+    // For Qwen3, n_embd is equal to n_vocab
+    if (model.arch == LLM_ARCH_QWEN3) {
+        n_embd = n_vocab;
+    }
 
     // TODO: use a per-batch flag for logits presence instead
     bool has_logits = !cparams.embeddings;
@@ -1182,8 +1200,19 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd = true;
     }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    // For Qwen3 models, both logits and embeddings point to the same tensor
+    bool shared_tensor = (model.arch == LLM_ARCH_QWEN3);
+
+    // Adjust buffer sizes for the case where both tensors are shared
+    if (shared_tensor && has_logits && has_embd) {
+        // For Qwen3, we only need one buffer since logits and embeddings share the same tensor
+        logits_size = n_vocab * n_outputs_max;
+        embd_size = 0; // No need for a separate embedding buffer
+    } else {
+        // Normal case - separate buffers
+        logits_size = has_logits ? n_vocab * n_outputs_max : 0;
+        embd_size = has_embd ? n_embd * n_outputs_max : 0;
+    }
 
     if (output_ids.empty()) {
         // init, never resized afterwards

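Taken together, the decode() and output_reserve() changes make Qwen3 treat the lm_head output as both logits and embedding, so only one vocab-sized buffer is reserved per output row. The sizing rule can be summarized as a small standalone function; this is a sketch of the logic above, not code from the commit, and the struct and function names are hypothetical:

#include <cstdint>

// Hypothetical helper mirroring the output_reserve() sizing rule above;
// the names output_buf_sizes and compute_output_sizes are illustrative,
// not llama.cpp API.
struct output_buf_sizes {
    int64_t logits_size; // floats reserved for logits
    int64_t embd_size;   // floats reserved for embeddings
};

static output_buf_sizes compute_output_sizes(
        bool has_logits, bool has_embd, bool shared_tensor,
        int64_t n_vocab, int64_t n_embd, int64_t n_outputs_max) {
    if (shared_tensor && has_logits && has_embd) {
        // Qwen3 path: one vocab-sized buffer backs both logits and embeddings
        return { n_vocab * n_outputs_max, 0 };
    }
    // default path: independent buffers, one per output type
    return { has_logits ? n_vocab * n_outputs_max : 0,
             has_embd   ? n_embd  * n_outputs_max : 0 };
}

Note that embd_size collapses to zero in the shared case, which is why the relaxed GGML_ASSERT in decode() only needs to check the logits buffer.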
src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
@@ -7074,13 +7074,13 @@ struct llm_build_qwen3 : public llm_graph_context {
                 LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
-        res->t_embd = cur;
 
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
+        res->t_embd = cur;
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -7202,13 +7202,13 @@ struct llm_build_qwen3moe : public llm_graph_context {
                 LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
-        res->t_embd = cur;
 
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
+        res->t_embd = cur;
 
         ggml_build_forward_expand(gf, cur);
     }

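With t_embd now assigned after build_lora_mm(model.output, cur), the embedding tensor aliases the lm_head output, so anything a caller reads back as an "embedding" from a Qwen3 graph is a vocab-sized row of logits. A caller-side sketch under that assumption (print_embd_dim is a hypothetical helper; the llama.h getters are real public API):

#include "llama.h"
#include <cstdio>

// Hypothetical caller-side check (not from this commit): with t_embd aliased
// to the lm_head output, a Qwen3 embedding row has n_vocab dimensions rather
// than hparams.n_embd.
static void print_embd_dim(llama_context * ctx, const llama_model * model, int32_t i) {
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int32_t n_dim = llama_vocab_n_tokens(vocab); // n_vocab, not n_embd

    const float * embd = llama_get_embeddings_ith(ctx, i); // now a row of logits
    if (embd != NULL) {
        printf("output %d: %d-dimensional vector, first value %.4f\n", i, n_dim, embd[0]);
    }
}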
tools/server/server.cpp

Lines changed: 13 additions & 1 deletion
@@ -2545,7 +2545,18 @@ struct server_context {
             res->n_tokens = slot.n_prompt_tokens;
             res->oaicompat = slot.params.oaicompat;
 
-            const int n_embd = llama_model_n_embd(model);
+            int n_embd = llama_model_n_embd(model);
+            // For Qwen3 specific handling
+            bool is_qwen3 = false;
+            char arch_name[128] = {0};
+            if (llama_model_meta_val_str(model, "general.architecture", arch_name, sizeof(arch_name)) > 0) {
+                is_qwen3 = (strcmp(arch_name, "qwen3") == 0 || strcmp(arch_name, "qwen3moe") == 0);
+                if (is_qwen3) {
+                    // Get vocabulary size for Qwen3 models - they use n_vocab as embedding size
+                    n_embd = llama_vocab_n_tokens(vocab);
+                    SLT_INF(slot, "Qwen3 model embedding size: %d\n", n_embd);
+                }
+            }
 
             std::vector<float> embd_res(n_embd, 0.0f);
 
@@ -2556,6 +2567,7 @@ struct server_context {
 
                 const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
                 if (embd == NULL) {
+                    fprintf(stderr, "Failed to get embeddings\n");
                     embd = llama_get_embeddings_ith(ctx, i);
                 }
 

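The server change detects Qwen3 by string-comparing the general.architecture GGUF metadata key, presumably because the internal architecture enum is not exposed through the public llama.h API that server.cpp uses. A minimal sketch of that check factored into a helper (model_is_qwen3 is an illustrative name, not server code; the metadata key and API calls are real):

#include "llama.h"
#include <cstring>

// Illustrative helper factoring out the architecture check added above.
static bool model_is_qwen3(const llama_model * model) {
    char arch_name[128] = {0};
    // general.architecture is the GGUF key that names the model architecture
    if (llama_model_meta_val_str(model, "general.architecture",
                                 arch_name, sizeof(arch_name)) > 0) {
        return strcmp(arch_name, "qwen3") == 0 || strcmp(arch_name, "qwen3moe") == 0;
    }
    return false;
}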