@@ -153,6 +153,7 @@ enum llm_arch {
153153 LLM_ARCH_QWEN,
154154 LLM_ARCH_QWEN2,
155155 LLM_ARCH_QWEN2MOE,
156+ LLM_ARCH_QWEN2VL,
156157 LLM_ARCH_QWEN3,
157158 LLM_ARCH_QWEN3MOE,
158159 LLM_ARCH_PHI2,
@@ -205,6 +206,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
205206 { LLM_ARCH_QWEN, "qwen" },
206207 { LLM_ARCH_QWEN2, "qwen2" },
207208 { LLM_ARCH_QWEN2MOE, "qwen2moe" },
209+ { LLM_ARCH_QWEN2VL, "qwen2vl" },
208210 { LLM_ARCH_QWEN3, "qwen3" },
209211 { LLM_ARCH_QWEN3MOE, "qwen3moe" },
210212 { LLM_ARCH_PHI2, "phi2" },
@@ -298,6 +300,7 @@ enum llm_kv {
298300 LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
299301 LLM_KV_ROPE_SCALING_FINETUNED,
300302 LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
303+ LLM_KV_ROPE_DIMENSION_SECTIONS,
301304
302305 LLM_KV_SPLIT_NO,
303306 LLM_KV_SPLIT_COUNT,
@@ -399,6 +402,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
399402 { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
400403 { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
401404 { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
405+ { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
402406
403407 { LLM_KV_SPLIT_NO, "split.no" },
404408 { LLM_KV_SPLIT_COUNT, "split.count" },
@@ -465,6 +469,10 @@ enum llm_tensor {
465469 LLM_TENSOR_ATTN_V,
466470 LLM_TENSOR_ATTN_QKV,
467471 LLM_TENSOR_ATTN_OUT,
472+ LLM_TENSOR_ATTN_Q_BIAS,
473+ LLM_TENSOR_ATTN_K_BIAS,
474+ LLM_TENSOR_ATTN_V_BIAS,
475+ LLM_TENSOR_ATTN_OUT_BIAS,
468476 LLM_TENSOR_ATTN_NORM,
469477 LLM_TENSOR_ATTN_NORM_2,
470478 LLM_TENSOR_ATTN_OUT_NORM,
@@ -848,6 +856,27 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
848856 { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
849857 },
850858 },
859+ {
860+ LLM_ARCH_QWEN2VL,
861+ {
862+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
863+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
864+ { LLM_TENSOR_OUTPUT, "output" },
865+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
866+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
867+ { LLM_TENSOR_ATTN_Q_BIAS, "blk.%d.attn_q_bias" },
868+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
869+ { LLM_TENSOR_ATTN_K_BIAS, "blk.%d.attn_k_bias" },
870+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
871+ { LLM_TENSOR_ATTN_V_BIAS, "blk.%d.attn_v_bias" },
872+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
873+ { LLM_TENSOR_ATTN_OUT_BIAS, "blk.%d.attn_output_bias" },
874+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
875+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
876+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
877+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
878+ },
879+ },
851880 {
852881 LLM_ARCH_QWEN3,
853882 {
@@ -1973,6 +2002,7 @@ enum e_model {
19732002 MODEL_40B,
19742003 MODEL_65B,
19752004 MODEL_70B,
2005+ MODEL_72B,
19762006 MODEL_236B,
19772007 MODEL_314B,
19782008 MODEL_SMALL,
@@ -2038,6 +2068,9 @@ struct llama_hparams {
20382068 float rope_freq_scale_train_swa;
20392069 uint32_t n_ctx_orig_yarn;
20402070 float rope_yarn_log_mul;
2071+
2072+ // for qwen2vl - rope dimension sections
2073+ std::vector<int32_t> rope_sections;
20412074
20422075 // for State Space Models
20432076 uint32_t ssm_d_conv = 0;
@@ -4411,6 +4444,7 @@ static const char * llama_model_type_name(e_model type) {
44114444 case MODEL_40B: return "40B";
44124445 case MODEL_65B: return "65B";
44134446 case MODEL_70B: return "70B";
4447+ case MODEL_72B: return "72B";
44144448 case MODEL_236B: return "236B";
44154449 case MODEL_314B: return "314B";
44164450 case MODEL_SMALL: return "0.1B";
@@ -4768,6 +4802,31 @@ static void llm_load_hparams(
47684802 default: model.type = e_model::MODEL_UNKNOWN;
47694803 }
47704804 } break;
4805+ case LLM_ARCH_QWEN2VL:
4806+ {
4807+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
4808+
4809+ // Try to load rope dimension sections (optional for qwen2vl)
4810+ try {
4811+ int key_idx = gguf_find_key(ml.meta, llm_kv(LLM_KV_ROPE_DIMENSION_SECTIONS).c_str());
4812+ if (key_idx >= 0) {
4813+ auto arr_info = GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ml.meta, key_idx);
4814+ if (arr_info.gt == GGUF_TYPE_INT32 && arr_info.length == 4) {
4815+ hparams.rope_sections.resize(4);
4816+ memcpy(hparams.rope_sections.data(), arr_info.data, 4 * sizeof(int32_t));
4817+ }
4818+ }
4819+ } catch (...) {
4820+ // rope_sections are optional - ignore errors
4821+ }
4822+
4823+ switch (hparams.n_layer) {
4824+ case 32: model.type = e_model::MODEL_2B; break;
4825+ case 40: model.type = e_model::MODEL_7B; break;
4826+ case 80: model.type = e_model::MODEL_72B; break;
4827+ default: model.type = e_model::MODEL_UNKNOWN;
4828+ }
4829+ } break;
47714830 case LLM_ARCH_QWEN3:
47724831 {
47734832 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -6691,6 +6750,46 @@ static bool llm_load_tensors(
66916750 layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp});
66926751 }
66936752 } break;
6753+ case LLM_ARCH_QWEN2VL:
6754+ {
6755+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
6756+
6757+ // output
6758+ {
6759+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
6760+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
6761+ // if output is NULL, init from the input tok embed
6762+ if (model.output == NULL) {
6763+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
6764+ }
6765+ }
6766+
6767+ for (int i = 0; i < n_layer; ++i) {
6768+ ggml_context * ctx_layer = ctx_for_layer(i);
6769+ ggml_context * ctx_split = ctx_for_layer_split(i);
6770+
6771+ auto & layer = model.layers[i];
6772+
6773+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
6774+
6775+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
6776+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
6777+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
6778+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
6779+
6780+ // bias tensors for qwen2vl
6781+ layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
6782+ layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
6783+ layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
6784+ layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
6785+
6786+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
6787+
6788+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
6789+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
6790+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
6791+ }
6792+ } break;
66946793 case LLM_ARCH_QWEN3:
66956794 {
66966795 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -10898,6 +10997,121 @@ struct llm_build_context {
1089810997 return gf;
1089910998 }
1090010999
11000+ struct ggml_cgraph * build_qwen2vl() {
11001+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
11002+
11003+ const int64_t n_embd_head = hparams.n_embd_head_v;
11004+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
11005+ GGML_ASSERT(n_embd_head == hparams.n_rot);
11006+
11007+ struct ggml_tensor * cur;
11008+ struct ggml_tensor * inpL;
11009+
11010+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
11011+
11012+ // inp_pos - contains the positions
11013+ struct ggml_tensor * inp_pos = build_inp_pos();
11014+
11015+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
11016+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
11017+
11018+ for (int il = 0; il < n_layer; ++il) {
11019+ struct ggml_tensor * inpSA = inpL;
11020+
11021+ // norm
11022+ cur = llm_build_norm(ctx0, inpL, hparams,
11023+ model.layers[il].attn_norm, NULL,
11024+ LLM_NORM_RMS, cb, il);
11025+ cb(cur, "attn_norm", il);
11026+
11027+ // self-attention
11028+ {
11029+ // compute Q and K and RoPE them
11030+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
11031+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
11032+ cb(Qcur, "Qcur", il);
11033+
11034+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
11035+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
11036+ cb(Kcur, "Kcur", il);
11037+
11038+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
11039+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
11040+ cb(Vcur, "Vcur", il);
11041+
11042+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
11043+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
11044+
11045+ // Apply rope - qwen2vl uses standard rope for now
11046+ // TODO: Implement rope_multi with sections when available in llamafile
11047+ Qcur = ggml_rope_ext(
11048+ ctx0, Qcur, inp_pos, nullptr,
11049+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
11050+ ext_factor, attn_factor, beta_fast, beta_slow
11051+ );
11052+
11053+ Kcur = ggml_rope_ext(
11054+ ctx0, Kcur, inp_pos, nullptr,
11055+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
11056+ ext_factor, attn_factor, beta_fast, beta_slow
11057+ );
11058+
11059+ cb(Qcur, "Qcur", il);
11060+ cb(Kcur, "Kcur", il);
11061+
11062+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
11063+ model.layers[il].wo, model.layers[il].bo,
11064+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
11065+ }
11066+
11067+ if (il == n_layer - 1) {
11068+ // skip computing output for unused tokens
11069+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
11070+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
11071+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
11072+ }
11073+
11074+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
11075+ cb(ffn_inp, "ffn_inp", il);
11076+
11077+ // feed-forward network
11078+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
11079+ model.layers[il].ffn_norm, NULL,
11080+ LLM_NORM_RMS, cb, il);
11081+ cb(cur, "ffn_norm", il);
11082+
11083+ cur = llm_build_ffn(ctx0, lctx, cur,
11084+ model.layers[il].ffn_up, NULL, NULL,
11085+ model.layers[il].ffn_gate, NULL, NULL,
11086+ model.layers[il].ffn_down, NULL, NULL,
11087+ NULL,
11088+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
11089+ cb(cur, "ffn_out", il);
11090+
11091+ cur = ggml_add(ctx0, cur, ffn_inp);
11092+ cur = lctx.cvec.apply_to(ctx0, cur, il);
11093+ cb(cur, "l_out", il);
11094+
11095+ // input for next layer
11096+ inpL = cur;
11097+ }
11098+
11099+ cur = inpL;
11100+
11101+ cur = llm_build_norm(ctx0, cur, hparams,
11102+ model.output_norm, NULL,
11103+ LLM_NORM_RMS, cb, -1);
11104+ cb(cur, "result_norm", -1);
11105+
11106+ // lm_head
11107+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
11108+ cb(cur, "result_output", -1);
11109+
11110+ ggml_build_forward_expand(gf, cur);
11111+
11112+ return gf;
11113+ }
11114+
1090111115 struct ggml_cgraph * build_qwen3() {
1090211116 struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
1090311117
@@ -14736,6 +14950,10 @@ static struct ggml_cgraph * llama_build_graph(
1473614950 {
1473714951 result = llm.build_qwen2moe();
1473814952 } break;
14953+ case LLM_ARCH_QWEN2VL:
14954+ {
14955+ result = llm.build_qwen2vl();
14956+ } break;
1473914957 case LLM_ARCH_QWEN3:
1474014958 {
1474114959 result = llm.build_qwen3();
@@ -17963,6 +18181,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1796318181 case LLM_ARCH_QWEN:
1796418182 case LLM_ARCH_QWEN2:
1796518183 case LLM_ARCH_QWEN2MOE:
18184+ case LLM_ARCH_QWEN2VL:
1796618185 case LLM_ARCH_QWEN3:
1796718186 case LLM_ARCH_QWEN3MOE:
1796818187 case LLM_ARCH_PHI2:
0 commit comments