@@ -772,6 +772,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_PHI2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
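The new LLM_ARCH_QWEN3 case above only reads the RMS-norm epsilon and leaves every layer count at LLM_TYPE_UNKNOWN, while the neighbouring architectures in this switch map hparams.n_layer to a concrete size label. A minimal standalone sketch of that mapping pattern follows; the function name, layer counts, and size labels are hypothetical and not part of this commit.

#include <cstdint>
#include <cstdio>

// Hypothetical sketch, not from this commit: the usual follow-up for a new
// architecture in load_hparams() is to translate hparams.n_layer into a size
// label, which the QWEN3 case above still leaves as LLM_TYPE_UNKNOWN.
static const char * qwen3_size_from_layers(int64_t n_layer) {
    switch (n_layer) {
        case 28: return "small";    // hypothetical layer count
        case 36: return "8B-class"; // hypothetical layer count
        default: return "unknown";
    }
}

int main() {
    std::printf("36 layers -> %s\n", qwen3_size_from_layers(36));
    return 0;
}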
@@ -2254,6 +2261,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), { n_embd, n_ff_shexp}, 0);
                 }
             } break;
+        case LLM_ARCH_QWEN3:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                    // optional bias tensors
+                    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                }
+            } break;
         case LLM_ARCH_PHI2:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
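In the tensor shapes above, wq and wo use the full {n_embd, n_embd} projection while wk and wv use the narrower {n_embd, n_embd_gqa}, and the new attn_q_norm / attn_k_norm weights are only n_embd_head_k long because the norm is applied per attention head. A small self-contained check of those dimensional relationships, using hypothetical head counts that are not taken from this commit:

#include <cassert>
#include <cstdint>

// Illustrative check of the shape relationships used above; the head counts
// and per-head dimension are hypothetical stand-ins, not values from Qwen3.
int main() {
    const int64_t n_head        = 32;   // hypothetical query-head count
    const int64_t n_head_kv     = 8;    // hypothetical KV-head count (grouped-query attention)
    const int64_t n_embd_head_k = 128;  // hypothetical per-head dimension

    const int64_t n_embd     = n_head    * n_embd_head_k; // width of wq/wo: 4096
    const int64_t n_embd_gqa = n_head_kv * n_embd_head_k; // width of wk/wv: 1024

    // attn_q_norm / attn_k_norm store only n_embd_head_k weights because the
    // RMS norm runs over each head vector, not over the full projection.
    assert(n_embd % n_embd_gqa == 0);
    return 0;
}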
@@ -6335,6 +6380,139 @@ struct llm_build_qwen2moe : public llm_graph_context {
     }
 };
 
+struct llm_build_qwen3 : public llm_graph_context {
+    llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_phi2 : public llm_graph_context {
     llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
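Relative to the existing qwen2 builder, the notable additions in llm_build_qwen3 are the per-head RMS norms applied to Qcur and Kcur after the reshape to [n_embd_head, n_head, n_tokens] and before ggml_rope_ext, while the QKV bias adds remain guarded and optional. A minimal standalone reference of that per-head RMS norm (a sketch, not llama.cpp code; eps stands in for hparams.f_norm_rms_eps):

#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of the per-head RMS norm that build_norm(..., LLM_NORM_RMS, ...)
// applies to each head vector of Qcur/Kcur above: scale by
// 1/sqrt(mean(x^2) + eps), then by a learned weight.
static void rms_norm_head(std::vector<float> & x, const std::vector<float> & w, float eps) {
    float sum_sq = 0.0f;
    for (float v : x) {
        sum_sq += v * v;
    }
    const float scale = 1.0f / std::sqrt(sum_sq / x.size() + eps);
    for (size_t i = 0; i < x.size(); ++i) {
        x[i] = x[i] * scale * w[i];
    }
}

int main() {
    std::vector<float> head   = {1.0f, -2.0f, 3.0f, -4.0f}; // one hypothetical head vector
    std::vector<float> weight(head.size(), 1.0f);           // identity norm weight
    rms_norm_head(head, weight, 1e-6f);
    std::printf("first element after norm: %f\n", head[0]);
    return 0;
}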
@@ -11718,6 +11896,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
             } break;
+        case LLM_ARCH_QWEN3:
+            {
+                llm = std::make_unique<llm_build_qwen3>(*this, params, gf);
+            } break;
         case LLM_ARCH_PHI2:
             {
                 llm = std::make_unique<llm_build_phi2>(*this, params, gf);
@@ -12024,6 +12206,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_QWEN3:
         case LLM_ARCH_OLMO2:
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
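Adding LLM_ARCH_QWEN3 to this case group makes llama_model_rope_type return the same rotary flavour as the other Qwen and Phi entries, which I take to be the NeoX-style variant; that reading is an assumption from the surrounding cases, not something this hunk states. A standalone sketch of NeoX-style pairing, where dimension i rotates with dimension i + d/2 rather than with its adjacent neighbour (theta_base and the helper name are hypothetical):

#include <cmath>
#include <cstdio>
#include <vector>

// Hedged sketch of NeoX-style rotary embedding for a single head vector.
// theta follows the usual pos * base^(-2i/d) schedule; values are illustrative.
static void rope_neox_head(std::vector<float> & x, int pos, float theta_base) {
    const size_t half = x.size() / 2;
    for (size_t i = 0; i < half; ++i) {
        const float theta = pos * std::pow(theta_base, -2.0f * float(i) / float(x.size()));
        const float c  = std::cos(theta);
        const float s  = std::sin(theta);
        const float x0 = x[i];
        const float x1 = x[i + half];
        x[i]        = x0 * c - x1 * s;
        x[i + half] = x0 * s + x1 * c;
    }
}

int main() {
    std::vector<float> head = {1.0f, 0.0f, 0.0f, 0.0f}; // hypothetical 4-dim head
    rope_neox_head(head, /*pos=*/3, /*theta_base=*/10000.0f);
    std::printf("rotated head[0] = %f\n", head[0]);
    return 0;
}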