Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
Expand All @@ -59,6 +60,7 @@
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
#define TN_PATCH_BIAS "v.patch_embd.bias"
#define TN_ATTN_QKV "%s.blk.%d.attn_qkv.%s"
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
Expand Down Expand Up @@ -89,6 +91,9 @@
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
#define TN_DEEPSTACK_NORM "v.deepstack.%d.norm.%s" // qwen3vl deepstack
#define TN_DEEPSTACK_FC1 "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack

// mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
Expand Down Expand Up @@ -123,6 +128,7 @@ enum projector_type {
PROJECTOR_TYPE_MINICPMV,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_QWEN2VL,
PROJECTOR_TYPE_QWEN3VL,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
Expand All @@ -146,6 +152,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
Expand Down
254 changes: 247 additions & 7 deletions examples/mtmd/clip.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/mtmd/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
int clip_is_minicpmv(const struct clip_ctx * ctx);
bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_qwen2vl(const struct clip_ctx * ctx);
bool clip_is_qwen3vl(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);

Expand Down
2 changes: 1 addition & 1 deletion examples/mtmd/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ struct mtmd_context {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
img_end = "[IMG_END]";

} else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
} else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
img_beg = "<|vision_start|>";
img_end = "<|vision_end|>";
Expand Down
1 change: 1 addition & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000

#define GGML_MROPE_SECTIONS 4

Expand Down
45 changes: 28 additions & 17 deletions ggml/src/ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
Expand All @@ -152,17 +152,27 @@ static __global__ void rope_multi(
const int sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
}

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
Expand Down Expand Up @@ -276,7 +286,7 @@ template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
Expand All @@ -287,11 +297,11 @@ static void rope_multi_cuda(
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}

Expand Down Expand Up @@ -369,6 +379,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_mrope) {
Expand Down Expand Up @@ -406,11 +417,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}
Expand Down
8 changes: 6 additions & 2 deletions ggml/src/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ struct vk_op_rope_push_constants {
uint32_t s1;
uint32_t s2;
int32_t sections[4];
uint32_t is_imrope;
uint32_t is_back;
};

Expand Down Expand Up @@ -6754,6 +6755,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
const int mode = ((const int32_t *) dst->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_neox) {
Expand All @@ -6763,7 +6765,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
}
} else if (is_mrope && !is_vision) {
} else if ((is_mrope || is_imrope) && !is_vision) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_multi_f32;
}
Expand Down Expand Up @@ -7970,6 +7972,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
}

const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;

float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

Expand All @@ -7982,7 +7986,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
sections[0], sections[1], sections[2], sections[3], backprop
sections[0], sections[1], sections[2], sections[3], is_imrope, backprop
}, dryrun);
}

Expand Down
36 changes: 25 additions & 11 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -18339,7 +18339,7 @@ static void ggml_rope_cache_init(
}

static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
Expand Down Expand Up @@ -18374,14 +18374,26 @@ static void ggml_mrope_cache_init(
}

float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
theta = theta_h;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
theta = theta_w;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
theta = theta_t;
} else {
theta = theta_e;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
}

rope_yarn(
Expand Down Expand Up @@ -18454,6 +18466,7 @@ static void ggml_compute_forward_rope_f32(

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_mrope) {
Expand Down Expand Up @@ -18492,7 +18505,7 @@ static void ggml_compute_forward_rope_f32(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}

Expand Down Expand Up @@ -18640,6 +18653,7 @@ static void ggml_compute_forward_rope_f16(

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_mrope) {
Expand Down Expand Up @@ -18678,7 +18692,7 @@ static void ggml_compute_forward_rope_f16(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}

Expand Down
1 change: 1 addition & 0 deletions ggml/src/vulkan-shaders/rope_head.comp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
uint s1;
uint s2;
int sections[4];
uint is_imrope;
uint is_back;
} p;

Expand Down
34 changes: 23 additions & 11 deletions ggml/src/vulkan-shaders/rope_multi.comp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,29 @@ void main() {
const uint sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
if (p.is_imrope != 0) {
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
} else {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
}

const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ extern "C" {
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
};

Expand Down
5 changes: 4 additions & 1 deletion src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PLAMO, "plamo" },
Expand Down Expand Up @@ -110,7 +112,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
Expand Down
3 changes: 3 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PLAMO,
Expand Down Expand Up @@ -99,6 +101,7 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_NUM_DEEPSTACK_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
Expand Down
Loading