diff --git a/examples/mtmd/clip-impl.h b/examples/mtmd/clip-impl.h
index fb765a4fe..b169485c2 100644
--- a/examples/mtmd/clip-impl.h
+++ b/examples/mtmd/clip-impl.h
@@ -36,6 +36,7 @@
 #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
+#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -59,6 +60,7 @@
 #define TN_PATCH_EMBD "v.patch_embd.weight" // do not rename tensor with ".0" postfix for backward compat
 #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
 #define TN_PATCH_BIAS "v.patch_embd.bias"
+#define TN_ATTN_QKV "%s.blk.%d.attn_qkv.%s"
 #define TN_ATTN_K "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V "%s.blk.%d.attn_v.%s"
@@ -89,6 +91,9 @@
 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
 #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
 #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
+#define TN_DEEPSTACK_NORM "v.deepstack.%d.norm.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC1 "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack
 // minicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -123,6 +128,7 @@ enum projector_type {
    PROJECTOR_TYPE_MINICPMV,
    PROJECTOR_TYPE_GLM_EDGE,
    PROJECTOR_TYPE_QWEN2VL,
+    PROJECTOR_TYPE_QWEN3VL,
    PROJECTOR_TYPE_GEMMA3,
    PROJECTOR_TYPE_IDEFICS3,
    PROJECTOR_TYPE_PIXTRAL,
@@ -146,6 +152,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
    { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
    { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
    { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
+    { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
    { PROJECTOR_TYPE_GEMMA3, "gemma3"},
    { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
    { PROJECTOR_TYPE_PIXTRAL, "pixtral"},
diff --git a/examples/mtmd/clip.cpp b/examples/mtmd/clip.cpp
index 0f251ed5e..b04d0846d 100644
--- a/examples/mtmd/clip.cpp
+++ b/examples/mtmd/clip.cpp
@@ -204,6 +204,7 @@ struct clip_hparams {
    int32_t attn_window_size = 0;
    int32_t n_wa_pattern = 0;
    int32_t spatial_merge_size = 0;
+    std::vector<int32_t> deepstack_layers; // qwen3vl multi-level feature fusion
    // audio
    int32_t n_mel_bins = 0; // whisper preprocessor
@@ -223,6 +224,8 @@ struct clip_layer {
    ggml_tensor * q_b = nullptr;
    ggml_tensor * v_w = nullptr;
    ggml_tensor * v_b = nullptr;
+    ggml_tensor * qkv_w = nullptr;
+    ggml_tensor * qkv_b = nullptr;
    ggml_tensor * o_w = nullptr;
    ggml_tensor * o_b = nullptr;
@@ -248,6 +251,18 @@ struct clip_layer {
    // layer scale (no bias)
    ggml_tensor * ls_1_w = nullptr;
    ggml_tensor * ls_2_w = nullptr;
+
+    // qwen3vl deepstack merger
+    ggml_tensor * deepstack_norm_w = nullptr;
+    ggml_tensor * deepstack_norm_b = nullptr;
+    ggml_tensor * deepstack_fc1_w = nullptr;
+    ggml_tensor * deepstack_fc1_b = nullptr;
+    ggml_tensor * deepstack_fc2_w = nullptr;
+    ggml_tensor * deepstack_fc2_b = nullptr;
+
+    bool has_deepstack() const {
+        return deepstack_fc1_w != nullptr;
+    }
 };
 struct clip_model {
@@ -267,6 +282,8 @@ struct clip_model {
    std::vector<clip_layer> layers;
+    int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer
+
    ggml_tensor * post_ln_w;
    ggml_tensor * post_ln_b;
@@ -847,6 +864,189 @@ struct clip_graph {
        return gf;
    }
+    // Qwen3VL
+    ggml_cgraph * build_qwen3vl() {
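+        // overview (sketch, following the Qwen3-VL reference implementation):
+        // two conv2d patch embeddings are summed, patches are regrouped into
+        // 2x2 merge windows, learned position embeddings are added, and M-RoPE
+        // is applied in every attention block; layers flagged as deepstack
+        // additionally emit a merged feature tensor that is concatenated onto
+        // the projector output at the end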
+        GGML_ASSERT(model.patch_bias != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
+        const int batch_size = 1;
+        const int n_pos = n_patches;
+        const int num_position_ids = n_pos * 4; // m-rope requires 4 dims per position
+
+        norm_type norm_t = NORM_TYPE_NORMAL;
+
+        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+        // second conv dimension
+        {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+
+            inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
+            inp = ggml_cont_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+            inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+            inp = ggml_cont_3d(
+                ctx0, inp,
+                n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        // add patch bias
+        if (model.patch_bias != nullptr) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+
+        // calculate absolute position embedding and apply
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+        learned_pos_embd = ggml_cont_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        learned_pos_embd = ggml_reshape_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
+        learned_pos_embd = ggml_cont_3d(
+            ctx0, learned_pos_embd,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+        inp = ggml_add(ctx0, inp, learned_pos_embd);
+        cb(inp, "inp_pos_emb", -1);
+
+        ggml_tensor * inpL = inp;
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        }
+
+        // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
+        ggml_tensor * deepstack_features = nullptr;
+        const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
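+        // (worked example, assuming patch_size = 16 and spatial_merge_size = 2:
+        // a 1024x1024 image yields 64x64 = 4096 patches, each 2x2 window is
+        // merged, so the deepstack mergers and the final projector operate on
+        // 4096 / 4 = 1024 rows of width n_embd * 4)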
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "ln1", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], 0);
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], n_embd * sizeof(float));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                    cur->nb[1], 2 * n_embd * sizeof(float));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // apply M-RoPE
+                Qcur = ggml_rope_multi(
+                    ctx0, Qcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+                Kcur = ggml_rope_multi(
+                    ctx0, Kcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+                cb(Qcur, "Qcur_rope", il);
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, i.e. the residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            if (layer.has_deepstack()) {
+                ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
+                feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
+                feat = build_ffn(feat,
+                    layer.deepstack_fc1_w, layer.deepstack_fc1_b,
+                    nullptr, nullptr,
+                    layer.deepstack_fc2_w, layer.deepstack_fc2_b,
+                    ffn_op_type::FFN_GELU, il);
+
+                if (!deepstack_features) {
+                    deepstack_features = feat;
+                } else {
+                    // concat along the feature dimension
+                    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
+                }
+            }
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+        }
+
+        // multimodal projection
+        ggml_tensor * embeddings = inpL;
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+
+        embeddings = build_ffn(embeddings,
+            model.mm_0_w, model.mm_0_b,
+            nullptr, nullptr,
+            model.mm_1_w, model.mm_1_b,
+            ffn_op_type::FFN_GELU, -1);
+
+        embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
    ggml_cgraph * build_minicpmv() {
        const int batch_size = 1;
@@ -2119,6 +2319,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
            {
                res = graph.build_qwen2vl();
            } break;
+        case PROJECTOR_TYPE_QWEN3VL:
+            {
+                res = graph.build_qwen3vl();
+            } break;
        case PROJECTOR_TYPE_MINICPMV:
            {
                res = 
graph.build_minicpmv();
@@ -2424,6 +2628,12 @@ struct clip_model_loader {
                    hparams.warmup_image_size = hparams.patch_size * 8;
                    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                } break;
+            case PROJECTOR_TYPE_QWEN3VL:
+                {
+                    hparams.image_size = 1024; // still need this?
+                    hparams.warmup_image_size = hparams.patch_size * 8;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                } break;
            case PROJECTOR_TYPE_LLAMA4:
                {
                    hparams.rope_theta = 10000.0f;
@@ -2462,6 +2672,9 @@
                LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
                LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
                LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+                if (hparams.spatial_merge_size > 0) {
+                    LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
+                }
            } else if (is_audio) {
                LOG_INF("\n--- audio hparams ---\n");
                LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
@@ -2533,9 +2746,16 @@
            model.layers.resize(hparams.n_layer);
            for (int il = 0; il < hparams.n_layer; ++il) {
                auto & layer = model.layers[il];
-                layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
-                layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
-                layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
+                // try the combined qkv weight first; if absent, fall back to separate k/q/v weights
+                layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
+                if (!layer.qkv_w) {
+                    // combined tensor not present => the separate tensors are required (hence no 'false' argument)
+                    layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
+                    layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
+                    layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
+                }
+
+                // remaining attention tensors (output / norms / ln)
                layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
                layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
                layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
@@ -2544,9 +2764,16 @@
                layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
                layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
-                layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
-                layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
-                layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
+                // try the combined qkv bias first; if absent, fall back to separate k/q/v biases
+                layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
+                if (!layer.qkv_b) {
+                    // combined bias not present => fall back to the separate biases ('false' because biases are optional)
+                    layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
+                    layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
+                    layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
+                }
+
+                // remaining optional biases
                layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
                layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
                layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
@@ -2559,6 +2786,18 @@
                layer.ff_down_w = 
get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); + + // qwen3vl deepstack layer + layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false); + layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false); + layer.deepstack_fc1_w = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false); + layer.deepstack_fc1_b = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false); + layer.deepstack_fc2_w = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false); + layer.deepstack_fc2_b = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false); + if (layer.has_deepstack()) { + model.n_deepstack_layers++; + } + // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check! bool is_ffn_swapped = ( @@ -2694,6 +2933,13 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_QWEN3VL: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GEMMA3: { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); @@ -3557,7 +3803,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -3736,7 +3982,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0); } return n_total; @@ -3744,7 +3990,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0); } return 1; @@ -3800,6 +4046,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_QWEN2VL: case 
PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
            {
                // dynamic size (2 conv, so double patch size)
                int patch_size = params.patch_size * 2;
@@ -4104,6 +4351,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                set_input_f32("pos_embed", pos_embed);
            } break;
        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN3VL:
            {
                const int merge_ratio = 2;
                const int pw = image_size_width / patch_size;
@@ -4354,6 +4602,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
        case PROJECTOR_TYPE_QWEN2VL:
        case PROJECTOR_TYPE_QWEN25VL:
            return ctx->model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_QWEN3VL:
+            // main path + deepstack paths
+            return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers);
        case PROJECTOR_TYPE_GEMMA3:
            return ctx->model.mm_input_proj_w->ne[0];
        case PROJECTOR_TYPE_IDEFICS3:
@@ -4388,7 +4639,8 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
    return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
 }
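+// note: clip.h below declares clip_is_qwen3vl() without a matching definition
+// in this file; a minimal definition (a sketch, mirroring clip_is_qwen2vl
+// above) would be:
+bool clip_is_qwen3vl(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
+}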
 bool clip_is_llava(const struct clip_ctx * ctx) {
diff --git a/examples/mtmd/clip.h b/examples/mtmd/clip.h
index 3387cdbd3..bb3066d06 100644
--- a/examples/mtmd/clip.h
+++ b/examples/mtmd/clip.h
@@ -93,6 +93,7 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
 int clip_is_minicpmv(const struct clip_ctx * ctx);
 bool clip_is_glm(const struct clip_ctx * ctx);
 bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+bool clip_is_qwen3vl(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
 bool clip_is_gemma3(const struct clip_ctx * ctx);
diff --git a/examples/mtmd/mtmd.cpp b/examples/mtmd/mtmd.cpp
index 94322e141..ac3aa20b7 100644
--- a/examples/mtmd/mtmd.cpp
+++ b/examples/mtmd/mtmd.cpp
@@ -252,7 +252,7 @@ struct mtmd_context {
            // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
            img_end = "[IMG_END]";
-        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
+        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
            // <|vision_start|> ... (image embeddings) ... <|vision_end|>
            img_beg = "<|vision_start|>";
            img_end = "<|vision_end|>";
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index d2cb164c2..a97743efd 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -259,6 +259,7 @@
 #define GGML_ROPE_TYPE_NEOX 2
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 #define GGML_MROPE_SECTIONS 4
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index d058504cd..b5a27684e 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -125,7 +125,7 @@ template <typename T, bool has_ff> static __global__ void rope_multi(
    const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
    const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
@@ -152,17 +152,27 @@ static __global__ void rope_multi(
    const int sector = (i0 / 2) % sect_dims;
    float theta_base = 0.0;
-    if (sector < sections.v[0]) {
-        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
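+    // interleaved m-rope (sketch of the layout, assuming sections such as
+    // Qwen3-VL's [24, 20, 20, 0]): instead of giving each axis one contiguous
+    // run of rotary pairs, the t/h/w axes alternate per pair - sector 0 -> t,
+    // 1 -> h, 2 -> w, 3 -> t, ... - until an axis has used up its quota
+    // (sector < 3 * sections.v[i]); anything left over falls through to the
+    // fourth (extra) position stream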
+    if (is_imrope) {
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        } else { // e, matching the CPU and Vulkan implementations below
+            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        }
+    } else {
+        if (sector < sections.v[0]) {
+            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        }
    }
    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -276,7 +286,7 @@ template <typename T> static void rope_multi_cuda(
    const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
+    const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +297,11 @@ static void rope_multi_cuda(
    if (freq_factors == nullptr) {
        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    } else {
        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    }
 }
@@ -369,6 +379,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    if (is_mrope) {
@@ -406,11 +417,11 @@
    if (src0->type == GGML_TYPE_F32) {
        rope_multi_cuda(
            (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
    } else if (src0->type == GGML_TYPE_F16) {
        rope_multi_cuda(
            (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
    } else {
        GGML_ABORT("fatal error");
    }
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 778033d93..c880898cc 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -815,6 +815,7 @@ struct vk_op_rope_push_constants {
    uint32_t s1;
    uint32_t s2;
    int32_t sections[4];
+    uint32_t is_imrope;
    uint32_t is_back;
 };
@@ -6754,6 +6755,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
    const int mode = ((const int32_t *) dst->op_params)[2];
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    if (is_neox) {
@@ -6763,7 +6765,7 @@
    if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        return ctx->device->pipeline_rope_neox_f16;
    }
-    } else if (is_mrope && !is_vision) {
+    } else if ((is_mrope || is_imrope) && !is_vision) {
    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        return ctx->device->pipeline_rope_multi_f32;
    }
@@ -7970,6 +7972,8 @@ static void 
ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
    }
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
@@ -7982,7 +7986,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
        (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3], backprop
+        sections[0], sections[1], sections[2], sections[3], is_imrope, backprop
    }, dryrun);
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7d4c4feb1..3eebc1cc8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -18339,7 +18339,7 @@ static void ggml_rope_cache_init(
 }
 static void ggml_mrope_cache_init(
-     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -18374,14 +18374,26 @@
        }
        float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
-            theta = theta_h;
-        }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
-            theta = theta_w;
-        }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
+        if (is_imrope) { // qwen3vl applies interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
        }
        rope_yarn(
@@ -18454,6 +18466,7 @@ static void ggml_compute_forward_rope_f32(
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    if (is_mrope) {
@@ -18492,7 +18505,7 @@
            const int64_t p_w = pos[i2 + ne2 * 2];
            const int64_t p_e = pos[i2 + ne2 * 3];
            ggml_mrope_cache_init(
-                p_t, p_h, p_w, p_e, sections, is_vision,
+                p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
        }
@@ -18640,6 +18653,7 @@ static void ggml_compute_forward_rope_f16(
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    if (is_mrope) {
@@ -18678,7 +18692,7 @@
            const int64_t p_w = pos[i2 + ne2 * 2];
            const int64_t p_e = pos[i2 + ne2 * 
3]; ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_vision, + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); } diff --git a/ggml/src/vulkan-shaders/rope_head.comp b/ggml/src/vulkan-shaders/rope_head.comp index 96c9c4cbd..99f2fb731 100644 --- a/ggml/src/vulkan-shaders/rope_head.comp +++ b/ggml/src/vulkan-shaders/rope_head.comp @@ -29,6 +29,7 @@ layout (push_constant) uniform parameter { uint s1; uint s2; int sections[4]; + uint is_imrope; uint is_back; } p; diff --git a/ggml/src/vulkan-shaders/rope_multi.comp b/ggml/src/vulkan-shaders/rope_multi.comp index 5808710cc..fd8c59e2d 100644 --- a/ggml/src/vulkan-shaders/rope_multi.comp +++ b/ggml/src/vulkan-shaders/rope_multi.comp @@ -32,17 +32,29 @@ void main() { const uint sector = (i0 / 2) % sect_dims; float theta_base = 0.0; - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + if (p.is_imrope != 0) { + if (sector % 3 == 1 && sector < 3 * p.sections[1]) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } else { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } else { + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } } const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f;
diff --git a/include/llama.h b/include/llama.h
index ef00a1fb6..771b43956 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -122,6 +122,7 @@ extern "C" {
    LLAMA_ROPE_TYPE_NORM = 0,
    LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
    LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
+    LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
    LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
 };
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 6bb16a045..5966f0579 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -27,6 +27,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_QWEN2VL, "qwen2vl" },
    { LLM_ARCH_QWEN3, "qwen3" },
    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
+    { LLM_ARCH_QWEN3VL, "qwen3vl" },
+    { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
    { LLM_ARCH_PHI2, "phi2" },
    { LLM_ARCH_PHI3, "phi3" },
    { LLM_ARCH_PLAMO, "plamo" },
@@ -110,7 +112,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
    { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
    { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
-    { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
+    { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
+    { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
    { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
    { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
    { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 0069a8ee6..949a946e3 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -26,6 +26,8 @@ enum llm_arch {
    LLM_ARCH_QWEN2VL,
    LLM_ARCH_QWEN3,
    LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3VL,
+    LLM_ARCH_QWEN3VLMOE,
    LLM_ARCH_PHI2,
    LLM_ARCH_PHI3,
    LLM_ARCH_PLAMO,
@@ -99,6 +101,7 @@ enum llm_kv {
    LLM_KV_EXPERT_WEIGHTS_NORM,
    LLM_KV_EXPERT_GATING_FUNC,
    LLM_KV_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NUM_DEEPSTACK_LAYERS,
    LLM_KV_POOLING_TYPE,
    LLM_KV_LOGIT_SCALE,
    LLM_KV_DECODER_START_TOKEN_ID,
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 04e5a1423..10c741abc 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -240,7 +240,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
 }
 ggml_tensor * llm_build_context::build_inp_pos() {
-    int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
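+    // (m-rope/imrope tokens carry 4 position streams - time, height, width,
+    // plus one extra - laid out as 4 consecutive blocks of n_tokens entries;
+    // see llama_set_inputs below)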
    lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, int64_t(n_tokens)*n_pos_per_embd);
    cb(lctx.inp_pos, "inp_pos", -1);
    ggml_set_input(lctx.inp_pos);
@@ -3551,6 +3551,288 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
    return gf;
 }
+ggml_cgraph * llm_build_context::build_qwen3vl() {
+    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
+    const size_t n_deepstack_layers = hparams.n_deepstack_layers;
+    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
+
+    if (batch.embd) {
+        // Image input: split main embd and deepstack embds
+        struct ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
+        for (size_t i = 0; i < n_deepstack_layers; i++) {
+            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
+        }
+        inpL = inpL_main;
+    }
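+    // (how the split works: for image inputs, mtmd/clip hands over embeddings
+    // of width n_embd * (1 + n_deepstack_layers); the first n_embd columns are
+    // the regular token embedding, and each following n_embd-wide slice is
+    // re-added after layer 0, 1, ... in the loop below)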
+
+    // inp_pos - contains the positions
+    struct ggml_tensor * inp_pos = build_inp_pos();
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, inpL, hparams,
+            model.layers[il].attn_norm, NULL,
+            LLM_NORM_RMS, cb, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
+                model.layers[il].wq, nullptr,
+                model.layers[il].wk, nullptr,
+                model.layers[il].wv, nullptr,
+                0, il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_multi(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur", il);
+
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_multi(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Kcur, "Kcur", il);
+
+            cb(Vcur, "Vcur", il);
+
+            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                model.layers[il].wo, model.layers[il].bo,
+                Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+        }
+
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = llm_build_norm(ctx0, ffn_inp, hparams,
+            model.layers[il].ffn_norm, NULL,
+            LLM_NORM_RMS, cb, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = llm_build_ffn(ctx0, lctx, cur,
+            model.layers[il].ffn_up, NULL, NULL,
+            model.layers[il].ffn_gate, NULL, NULL,
+            model.layers[il].ffn_down, NULL, NULL,
+            NULL,
+            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = lctx.cvec.apply_to(ctx0, cur, il);
+        cb(cur, "l_out", il);
+
+        if (batch.embd && (size_t)il < n_deepstack_layers) {
+            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+            cb(cur, "deepstack_out", il);
+        }
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = llm_build_norm(ctx0, cur, hparams,
+        model.output_norm, NULL,
+        LLM_NORM_RMS, cb, -1);
+    cb(cur, "result_norm", -1);
+
+    // lm_head
+    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
+
+ggml_cgraph * llm_build_context::build_qwen3vlmoe() {
+    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+    // mutable variable, needed during the last layer of the computation to skip unused tokens
+    int32_t n_tokens = this->n_tokens;
+
+    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
+    const size_t n_deepstack_layers = hparams.n_deepstack_layers;
+    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
+
+    if (batch.embd) {
+        // Image input: split main embd and deepstack embds
+        struct ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
+        for (size_t i = 0; i < n_deepstack_layers; i++) {
+            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
+        }
+        inpL = inpL_main;
+    }
+
+    // inp_pos - contains the positions
+    struct ggml_tensor * inp_pos = build_inp_pos();
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, inpL, hparams,
+            model.layers[il].attn_norm, NULL,
+            LLM_NORM_RMS, cb, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
+                model.layers[il].wq, nullptr,
+                model.layers[il].wk, nullptr,
+                model.layers[il].wv, nullptr,
+                0, il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_multi(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur", il);
+
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_multi(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, sections, 
rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cb(Vcur, "Vcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, + cb, il, gf); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + if (batch.embd && (size_t)il < n_deepstack_layers) { + cur = ggml_add(ctx0, cur, deepstack_features[il]); + cb(cur, "deepstack_out", il); + } + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + ggml_cgraph * llm_build_context::build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -8216,6 +8498,14 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_qwen3moe(); } break; + case LLM_ARCH_QWEN3VL: + { + result = llm.build_qwen3vl(); + } break; + case LLM_ARCH_QWEN3VLMOE: + { + result = llm.build_qwen3vlmoe(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 150f35910..ff8ce7ced 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -192,8 +192,12 @@ struct llm_build_context { ggml_cgraph * build_qwen3(); + ggml_cgraph * build_qwen3vl(); + ggml_cgraph * build_qwen3moe(); + ggml_cgraph * build_qwen3vlmoe(); + ggml_cgraph * build_phi2(); ggml_cgraph * build_phi3(); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index d706c0915..d904e90a0 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -405,12 +405,32 @@ void llm_load_hparams( } } break; - case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3: { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 28: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_6B : e_model::MODEL_1_7B; break; + case 36: model.type = hparams.n_embd == 2560 ? 
e_model::MODEL_4B : e_model::MODEL_8B; break;
+                case 40: model.type = e_model::MODEL_14B; break;
+                case 64: model.type = e_model::MODEL_32B; break;
+                default: model.type = e_model::MODEL_UNKNOWN;
+            }
+        } break;
+        case LLM_ARCH_QWEN3VL:
+        {
+            ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
+            ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+            switch (hparams.n_layer) {
+                case 28: model.type = e_model::MODEL_1_7B; break;
+                case 36: model.type = hparams.n_embd == 2560 ? e_model::MODEL_4B : e_model::MODEL_8B; break;
+                case 64: model.type = e_model::MODEL_32B; break;
                default: model.type = e_model::MODEL_UNKNOWN;
            }
+            // since the vision model stacks deepstack features along the feature dim,
+            // store a fake "n_embd" for the text model: the main embd plus the deepstack embds
+            hparams.n_embd *= hparams.n_deepstack_layers + 1;
        } break;
        case LLM_ARCH_QWEN3MOE:
        {
@@ -418,8 +438,25 @@
            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
            switch (hparams.n_layer) {
+                case 48: model.type = e_model::MODEL_30B_A3B; break;
+                case 94: model.type = e_model::MODEL_235B_A22B; break;
+                default: model.type = e_model::MODEL_UNKNOWN;
+            }
+        } break;
+        case LLM_ARCH_QWEN3VLMOE:
+        {
+            ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
+            ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
+            ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+            switch (hparams.n_layer) {
+                case 48: model.type = e_model::MODEL_30B_A3B; break;
+                case 94: model.type = e_model::MODEL_235B_A22B; break;
                default: model.type = e_model::MODEL_UNKNOWN;
            }
+            // since the vision model stacks deepstack features along the feature dim,
+            // store a fake "n_embd" for the text model: the main embd plus the deepstack embds
+            hparams.n_embd *= hparams.n_deepstack_layers + 1;
        } break;
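+            // (worked example, assuming n_embd = 2560 and n_deepstack_layers = 3:
+            // the stored n_embd becomes 2560 * 4 = 10240, matching the width of
+            // the image embeddings reported by clip_n_mmproj_embd(); tensor
+            // shapes keep using the real 2560, recovered in llama-load-tensors.cpp)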
        case LLM_ARCH_PHI2:
        {
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 4235714ae..49c745655 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -110,7 +110,10 @@ struct llama_hparams {
    uint32_t n_no_rope_layer_step = 4;
    uint32_t n_attn_temp_floor_scale = 8192;
    float f_attn_temp_scale = 0.1;
-
+
+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
    llama_token dec_start_token_id = -1;
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index d27979115..6638043d1 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -224,7 +224,7 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
    [[maybe_unused]] const int64_t n_layer = hparams.n_layer; \
    [[maybe_unused]] const int64_t n_head = hparams.n_head(); \
    [[maybe_unused]] const int64_t n_head_kv = hparams.n_head_kv(); \
-    [[maybe_unused]] const int64_t n_embd = hparams.n_embd; \
+    [[maybe_unused]] const int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); /* for Qwen3-VL, recover the real embedding width by dividing by (n_deepstack_layers + 1); n_deepstack_layers is 0 for other models, so this is a no-op */ \
    [[maybe_unused]] const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); \
    [[maybe_unused]] const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); \
    [[maybe_unused]] const int64_t n_embd_head_k = hparams.n_embd_head_k; \
@@ -972,8 +972,13 @@ bool create_tensors_helper::create_qwen2_moe_tensors(const LLM_TN & tn) {
        layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-        GGML_ASSERT(n_expert > 0);
-        GGML_ASSERT(n_expert_used > 0);
+        if (n_expert == 0) {
+            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
+        }
+        if (n_expert_used == 0) {
+            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
+        }
+
        // MoE branch
        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
@@ -1030,7 +1035,17 @@ bool create_tensors_helper::create_qwen3_tensors(const LLM_TN & tn) {
 bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
    LOADING_PRELUDE
-    create_embd_output(tn, n_embd, n_vocab);
+    model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+    // output
+    {
+        model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+        model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+        // if output is NULL, init from the input tok embed
+        if (model.output == NULL) {
+            model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+        }
+    }
    for (int i = 0; i < n_layer; ++i) {
        ggml_context * ctx_layer = ctx_for_layer(i);
@@ -1051,8 +1066,12 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
        layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-        GGML_ASSERT(n_expert > 0);
-        GGML_ASSERT(n_expert_used > 0);
+        if (n_expert == 0) {
+            throw std::runtime_error("n_expert must be > 0 for QWEN3MOE");
+        }
+        if (n_expert_used == 0) {
+            throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE");
+        }
        // MoE branch
        const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used;
@@ -2419,7 +2438,7 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
    auto& hparams = model.hparams;
    const int64_t n_head = hparams.n_head();
    const int64_t n_head_kv = hparams.n_head_kv();
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); // for Qwen3-VL, recover the real embedding width; n_deepstack_layers is 0 for other models, so this is a no-op
    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
    const int64_t n_embd_head_k = hparams.n_embd_head_k;
    const int64_t n_embd_gqa = n_embd_v_gqa;
@@ -2540,8 +2559,10 @@ bool create_tensors_helper::create_tensors() {
        case LLM_ARCH_QWEN2MOE:
            use_mmap_buffer = create_qwen2_moe_tensors(tn); break;
        case LLM_ARCH_QWEN3:
+        case LLM_ARCH_QWEN3VL:
            use_mmap_buffer = create_qwen3_tensors(tn); break;
        case LLM_ARCH_QWEN3MOE:
+        case LLM_ARCH_QWEN3VLMOE:
            use_mmap_buffer = create_qwen3_moe_tensors(tn); break;
        case LLM_ARCH_PHI2:
            use_mmap_buffer = create_phi2_tensors(tn); break;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index d2b92e75d..1d59cbd1c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -429,6 +429,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
        { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
    },
 },
+    {
+        LLM_ARCH_QWEN3VL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3VLMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
 {
    LLM_ARCH_PHI2,
    {
@@ -1341,21 +1380,39 @@ const char * llama_model_type_name(e_model type) {
        case MODEL_80M: return "80M";
        case MODEL_109M: return "109M";
        case MODEL_137M: return "137M";
+        case MODEL_140M: return "140M";
        case MODEL_160M: return "160M";
+        case MODEL_190M: return "190M";
        case MODEL_220M: return "220M";
        case MODEL_250M: return "250M";
+        case MODEL_256M: return "256M";
        case MODEL_270M: return "270M";
        case MODEL_335M: return "335M";
+        case MODEL_350M: return "350M";
+        case MODEL_360M: return "360M";
        case MODEL_410M: return "410M";
        case MODEL_450M: return "450M";
+        case MODEL_475M: return "475M";
+        case MODEL_558M: return "558M";
+        case MODEL_700M: return "700M";
        case MODEL_770M: return "770M";
        case MODEL_780M: return "780M";
+        case MODEL_950M: return "950M";
+        case MODEL_0_3B: 
return "0.3B"; case MODEL_0_5B: return "0.5B"; + case MODEL_0_6B: return "0.6B"; case MODEL_1B: return "1B"; + case MODEL_1_2B: return "1.2B"; case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; + case MODEL_1_5B: return "1.5B"; + case MODEL_1_6B: return "1.6B"; + case MODEL_1_7B: return "1.7B"; + case MODEL_1_8B: return "1.8B"; case MODEL_2B: return "2B"; + case MODEL_2_6B: return "2.6B"; case MODEL_2_8B: return "2.8B"; + case MODEL_2_9B: return "2.9B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; case MODEL_6B: return "6B"; @@ -1370,17 +1427,19 @@ const char * llama_model_type_name(e_model type) { case MODEL_15B: return "15B"; case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; + case MODEL_27B: return "27B"; case MODEL_30B: return "30B"; case MODEL_32B: return "32B"; case MODEL_34B: return "34B"; case MODEL_35B: return "35B"; + case MODEL_36B: return "36B"; case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; - case MODEL_106B_A12B: return "106B.A12B"; + case MODEL_120B: return "120B"; case MODEL_142B: return "142B"; case MODEL_236B: return "236B"; - case MODEL_355B_A32B: return "355B.A32B"; + case MODEL_290B: return "290B"; case MODEL_314B: return "314B"; case MODEL_405B: return "405B"; case MODEL_671B: return "671B"; @@ -1388,20 +1447,30 @@ const char * llama_model_type_name(e_model type) { case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; case MODEL_XL: return "1.5B"; + case MODEL_A1_7B: return "A1.7B"; case MODEL_A2_7B: return "A2.7B"; case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; + case MODEL_16x3_8B: return "16x3.8B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; - case MODEL_27B: return "27B"; case MODEL_17B_16E: return "17Bx16E (Scout)"; case MODEL_17B_128E: return "17Bx128E (Maverick)"; - case MODEL_80B_A13B: return "80B.A13B"; - case MODEL_21B_A3B: return "21B.A3B"; - case MODEL_300B_A47B: return "300B.A47B"; + case MODEL_A13B: return "A13B"; + case MODEL_7B_A1B: return "7B.A1B"; + case MODEL_8B_A1B: return "8B.A1B"; case MODEL_16B_A1B: return "16B.A1B"; + case MODEL_21B_A3B: return "21B.A3B"; + case MODEL_30B_A3B: return "30B.A3B"; + case MODEL_80B_A13B: return "80B.A13B"; case MODEL_100B_A6B: return "100B.A6B"; - default: return "?B"; + case MODEL_106B_A12B: return "106B.A12B"; + case MODEL_235B_A22B: return "235B.A22B"; + case MODEL_300B_A47B: return "300B.A47B"; + case MODEL_355B_A32B: return "355B.A32B"; + case MODEL_E2B: return "E2B"; + case MODEL_E4B: return "E4B"; + default: return "?B"; } } diff --git a/src/llama-model.h b/src/llama-model.h index a26c7cb3d..991f7f9c4 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -24,22 +24,39 @@ enum e_model { MODEL_80M, MODEL_109M, MODEL_137M, + MODEL_140M, MODEL_160M, + MODEL_190M, MODEL_220M, MODEL_250M, + MODEL_256M, MODEL_270M, MODEL_335M, + MODEL_350M, + MODEL_360M, MODEL_410M, MODEL_450M, + MODEL_475M, + MODEL_558M, + MODEL_700M, MODEL_770M, MODEL_780M, + MODEL_950M, MODEL_0_3B, MODEL_0_5B, + MODEL_0_6B, MODEL_1B, + MODEL_1_2B, MODEL_1_3B, MODEL_1_4B, + MODEL_1_5B, + MODEL_1_6B, + MODEL_1_7B, + MODEL_1_8B, MODEL_2B, + MODEL_2_6B, MODEL_2_8B, + MODEL_2_9B, MODEL_3B, MODEL_4B, MODEL_6B, @@ -54,17 +71,19 @@ enum e_model { MODEL_15B, MODEL_16B, MODEL_20B, + MODEL_27B, MODEL_30B, MODEL_32B, MODEL_34B, MODEL_35B, + MODEL_36B, MODEL_40B, MODEL_65B, MODEL_70B, - MODEL_106B_A12B, + MODEL_120B, MODEL_142B, MODEL_236B, - MODEL_355B_A32B, + 
MODEL_290B,
    MODEL_314B,
    MODEL_405B,
    MODEL_671B,
@@ -72,22 +91,33 @@
    MODEL_MEDIUM,
    MODEL_LARGE,
    MODEL_XL,
+    MODEL_A1_7B,
    MODEL_A2_7B,
    MODEL_8x7B,
    MODEL_8x22B,
    MODEL_16x12B,
+    MODEL_16x3_8B,
    MODEL_10B_128x3_66B,
-    MODEL_21B_A3B, // Ernie MoE small
    MODEL_57B_A14B,
-    MODEL_27B,
    MODEL_17B_16E,
    MODEL_17B_128E,
-    MODEL_80B_A13B,
-    MODEL_300B_A47B, // Ernie MoE big
+    MODEL_A13B,
+    MODEL_7B_A1B,
+    MODEL_8B_A1B,
    MODEL_16B_A1B,
+    MODEL_21B_A3B, // Ernie MoE small
+    MODEL_30B_A3B,
+    MODEL_80B_A13B,
    MODEL_100B_A6B,
+    MODEL_106B_A12B,
+    MODEL_235B_A22B,
+    MODEL_300B_A47B, // Ernie MoE big
+    MODEL_355B_A32B,
+    MODEL_E2B,
+    MODEL_E4B,
 };
+
 struct llama_layer_nextn {
    struct ggml_tensor * eh_proj = nullptr;
    struct ggml_tensor * embed_tokens = nullptr;
diff --git a/src/llama.cpp b/src/llama.cpp
index 700b006d0..143421c5f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1165,6 +1165,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
    LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
    LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    // MRoPE (Multimodal Rotary Position Embedding) sections
+    if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
+        LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]);
+    }
    LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
    LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
    LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -1230,7 +1234,7 @@
    if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE) {
        LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
    }
-    if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE) {
+    if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE || model.arch == LLM_ARCH_QWEN3VLMOE) {
        LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
    }
@@ -2054,7 +2058,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
    auto tim1 = ggml_time_us();
#endif
    const int64_t n_tokens = batch.n_tokens;
-    const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
    if (batch.token && n_pos_per_embd == 4) {
        std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
        for (int i = 0; i < n_tokens; ++i) {
@@ -4656,6 +4660,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_QWEN2VL:
            return LLAMA_ROPE_TYPE_MROPE;
+        case LLM_ARCH_QWEN3VL:
+        case LLM_ARCH_QWEN3VLMOE:
+            return LLAMA_ROPE_TYPE_IMROPE;
+
        // all model arches should be listed explicitly here
        case LLM_ARCH_UNKNOWN:
            GGML_ABORT("unknown architecture");