39 changes: 34 additions & 5 deletions src/llama-context.cpp
@@ -893,7 +893,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int32_t n_vocab = vocab.n_tokens();
 
     const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd = hparams.n_embd;
+    int64_t n_embd = hparams.n_embd;
+
+    if (model.arch == LLM_ARCH_QWEN3 || model.arch == LLM_ARCH_QWEN3MOE) {
+        // Qwen3 uses a different embedding size
+        n_embd = n_vocab;
+    }
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -1067,7 +1072,15 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         if (n_outputs) {
             GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-            GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+
+            if (model.arch == LLM_ARCH_QWEN3 && cparams.embeddings) {
+                // For Qwen3 with embeddings enabled, we share the tensor between logits and embeddings
+                GGML_ASSERT(n_outputs * n_vocab <= (int64_t) logits_size);
+            } else {
+                // Standard check for other model architectures
+                GGML_ASSERT((n_outputs_prev + n_outputs) * n_vocab <= (int64_t) logits_size);
+            }
+
             ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
         }
     }
@@ -1213,7 +1226,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
-    const auto n_embd = hparams.n_embd;
+    int64_t n_embd = hparams.n_embd;
+
+    // For Qwen3, n_embd is equal to n_vocab
+    if (model.arch == LLM_ARCH_QWEN3) {
+        n_embd = n_vocab;
+    }
 
     // TODO: use a per-batch flag for logits presence instead
     bool has_logits = !cparams.embeddings;
@@ -1225,8 +1243,19 @@
         has_embd = true;
     }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size = has_embd ? n_embd*n_outputs_max : 0;
+    // For Qwen3 models, both logits and embeddings point to the same tensor
+    bool shared_tensor = (model.arch == LLM_ARCH_QWEN3);
+
+    // Adjust buffer sizes for the case where both tensors are shared
+    if (shared_tensor && has_logits && has_embd) {
+        // For Qwen3, we only need one buffer since logits and embeddings share the same tensor
+        logits_size = n_vocab * n_outputs_max;
+        embd_size = 0; // No need for a separate embedding buffer
+    } else {
+        // Normal case - separate buffers
+        logits_size = has_logits ? n_vocab * n_outputs_max : 0;
+        embd_size = has_embd ? n_embd * n_outputs_max : 0;
+    }
 
     if (output_ids.empty()) {
         // init, never resized afterwards
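Note on the output_reserve() hunk above: the whole change reduces to one sizing rule. Below is a minimal, self-contained sketch of that rule in isolation; output_sizes and reserve_sizes are illustrative stand-ins rather than llama.cpp internals, and the shared-tensor condition is assumed to apply exactly when both logits and embeddings are requested, as in the hunk.

// Standalone sketch of the sizing rule from the output_reserve() hunk above.
// The struct and function names are illustrative stand-ins, not llama.cpp API.
#include <cstdint>

struct output_sizes {
    int64_t logits_size; // number of floats reserved for logits
    int64_t embd_size;   // number of floats reserved for embeddings
};

static output_sizes reserve_sizes(bool has_logits, bool has_embd, bool shared_tensor,
                                  int64_t n_vocab, int64_t n_embd, int64_t n_outputs_max) {
    if (shared_tensor && has_logits && has_embd) {
        // shared path (the Qwen3 case in this PR): one vocab-sized buffer serves both outputs
        return { n_vocab * n_outputs_max, 0 };
    }
    // default path: independent buffers, each allocated only if requested
    return { has_logits ? n_vocab * n_outputs_max : 0,
             has_embd   ? n_embd  * n_outputs_max : 0 };
}

With a vocabulary much larger than the hidden size, the shared path reserves a single vocab-sized buffer per output row and skips the embedding buffer entirely.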
4 changes: 2 additions & 2 deletions src/llama-model.cpp
@@ -7077,13 +7077,13 @@ struct llm_build_qwen3 : public llm_graph_context {
                 LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
-        res->t_embd = cur;
 
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
+        res->t_embd = cur;
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -7205,13 +7205,13 @@ struct llm_build_qwen3moe : public llm_graph_context {
                 LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
-        res->t_embd = cur;
 
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
+        res->t_embd = cur;
 
         ggml_build_forward_expand(gf, cur);
     }
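Both llama-model.cpp hunks make the same one-line move: res->t_embd is now assigned after the lm_head projection instead of after the final RMS norm, so t_embd and t_logits point at the same vocab-sized tensor. A toy sketch of that re-wiring, using stand-in types (tensor, graph_result) rather than the real ggml_tensor/llm_graph_result:

// Toy illustration of the pointer re-wiring in llm_build_qwen3 / llm_build_qwen3moe.
// 'tensor' and 'graph_result' are stand-ins, not the real ggml/llama structures.
struct tensor { const char * name; };

struct graph_result {
    tensor * t_embd   = nullptr;
    tensor * t_logits = nullptr;
};

int main() {
    tensor hidden = { "result_norm"   }; // final RMS-norm output, n_embd wide
    tensor logits = { "result_output" }; // lm_head projection, n_vocab wide

    graph_result res;
    // before this PR: res.t_embd = &hidden;
    res.t_logits = &logits;
    res.t_embd   = &logits; // after this PR: embeddings alias the logits tensor

    return (res.t_embd == res.t_logits) ? 0 : 1;
}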
14 changes: 13 additions & 1 deletion tools/server/server.cpp
@@ -2540,7 +2540,18 @@ struct server_context {
         res->n_tokens = slot.n_prompt_tokens;
         res->oaicompat = slot.params.oaicompat;
 
-        const int n_embd = llama_model_n_embd(model);
+        int n_embd = llama_model_n_embd(model);
+        // For Qwen3 specific handling
+        bool is_qwen3 = false;
+        char arch_name[128] = {0};
+        if (llama_model_meta_val_str(model, "general.architecture", arch_name, sizeof(arch_name)) > 0) {
+            is_qwen3 = (strcmp(arch_name, "qwen3") == 0 || strcmp(arch_name, "qwen3moe") == 0);
+            if (is_qwen3) {
+                // Get vocabulary size for Qwen3 models - they use n_vocab as embedding size
+                n_embd = llama_vocab_n_tokens(vocab);
+                SLT_INF(slot, "Qwen3 model embedding size: %d\n", n_embd);
+            }
+        }
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
@@ -2551,6 +2562,7 @@
 
             const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
             if (embd == NULL) {
+                fprintf(stderr, "Failed to get embeddings\n");
                embd = llama_get_embeddings_ith(ctx, i);
            }
 
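The server.cpp change above can be read as a small sizing helper. The sketch below pulls it out under a hypothetical name (embd_output_size); the calls used (llama_model_meta_val_str, llama_model_get_vocab, llama_vocab_n_tokens, llama_model_n_embd) are the public llama.h API, and the qwen3/qwen3moe special case reflects the assumption introduced by this PR that their embedding output is vocab-sized.

// Hypothetical helper mirroring the server.cpp logic above: pick the embedding
// vector length for a loaded model, treating qwen3/qwen3moe as vocab-sized.
#include <cstdint>
#include <cstring>
#include "llama.h"

static int32_t embd_output_size(const llama_model * model) {
    char arch[128] = {0};
    if (llama_model_meta_val_str(model, "general.architecture", arch, sizeof(arch)) > 0) {
        if (strcmp(arch, "qwen3") == 0 || strcmp(arch, "qwen3moe") == 0) {
            // assumption from this PR: the "embedding" output is the lm_head projection
            return llama_vocab_n_tokens(llama_model_get_vocab(model));
        }
    }
    return llama_model_n_embd(model);
}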