
Commit 069f9ef

mtmd : Support jinja in libmtmd (Only for QwenVL and Qwen Omni)
1 parent 6efcd65 commit 069f9ef

10 files changed: +144 -23


common/chat.cpp

Lines changed: 2 additions & 1 deletion
@@ -1727,7 +1727,8 @@ static common_chat_params common_chat_templates_apply_jinja(
         : *tmpls->template_default;
     const auto & src = tmpl.source();
     const auto & caps = tmpl.original_caps();
-    params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+    bool concat_text = !inputs.no_part_concat && !tmpl.original_caps().requires_typed_content;
+    params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, concat_text);
     params.add_generation_prompt = inputs.add_generation_prompt;
     params.tool_choice = inputs.tool_choice;
     params.enable_thinking = inputs.enable_thinking;

common/chat.h

Lines changed: 3 additions & 0 deletions
@@ -127,6 +127,9 @@ struct common_chat_templates_inputs {
     bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
     std::map<std::string, std::string> chat_template_kwargs;
+
+    // If true, jinja will not concatenate content parts into a single part; useful for media parts
+    bool no_part_concat = false;
 };

 struct common_chat_params {
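
With the flag in place, a caller that wants per-part rendering sets it on the template inputs before applying the template. A minimal sketch, assuming `tmpls` is a chat-templates handle loaded elsewhere and the message contents are illustrative:

    common_chat_templates_inputs inputs;
    inputs.use_jinja             = true;
    inputs.add_generation_prompt = true;
    inputs.no_part_concat        = true; // keep media parts as typed content entries

    common_chat_msg msg;
    msg.role = "user";
    msg.content_parts.push_back({"image", ""});                   // media placeholder part
    msg.content_parts.push_back({"text", "Describe the image."});
    inputs.messages = {msg};

    auto formatted = common_chat_templates_apply(tmpls, inputs);
    // formatted.prompt keeps the image part distinct instead of one concatenated text blob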

include/llama.h

Lines changed: 3 additions & 0 deletions
@@ -1057,6 +1057,9 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);

+    LLAMA_API llama_token llama_vocab_image_token(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_audio_token(const struct llama_vocab * vocab);
+
     DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
     DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
     DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
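
Both new accessors follow the FIM-token convention above and return LLAMA_TOKEN_NULL when the model defines no such token, so callers should check before use. A minimal caller-side sketch (model loading omitted):

    const llama_vocab * vocab = llama_model_get_vocab(model);
    llama_token img = llama_vocab_image_token(vocab);
    if (img != LLAMA_TOKEN_NULL) {
        // the model ships a dedicated image placeholder token, e.g. <|IMAGE|>
        printf("image token: %s\n", llama_vocab_get_text(vocab, img));
    }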

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
@@ -217,6 +217,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_PAD_ID,  "tokenizer.ggml.fim_pad_token_id" },
     { LLM_KV_TOKENIZER_FIM_REP_ID,  "tokenizer.ggml.fim_rep_token_id" },
     { LLM_KV_TOKENIZER_FIM_SEP_ID,  "tokenizer.ggml.fim_sep_token_id" },
+    { LLM_KV_TOKENIZER_IMAGE_ID,    "tokenizer.ggml.image_token_id"   },
+    { LLM_KV_TOKENIZER_AUDIO_ID,    "tokenizer.ggml.audio_token_id"   },

     { LLM_KV_ADAPTER_TYPE,          "adapter.type"       },
     { LLM_KV_ADAPTER_LORA_ALPHA,    "adapter.lora.alpha" },

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
@@ -225,6 +225,9 @@ enum llm_kv {

     LLM_KV_CLASSIFIER_OUTPUT_LABELS,

+    LLM_KV_TOKENIZER_IMAGE_ID,
+    LLM_KV_TOKENIZER_AUDIO_ID,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,

src/llama-vocab.cpp

Lines changed: 39 additions & 0 deletions
@@ -1265,6 +1265,8 @@ struct llama_vocab::impl {
     llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
     llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
     llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
+    llama_token special_image_id   = LLAMA_TOKEN_NULL;
+    llama_token special_audio_id   = LLAMA_TOKEN_NULL;

     // tokenizer flags
     bool add_space_prefix = false;
@@ -1695,6 +1697,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
     }

+    const int image_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_IMAGE_ID).c_str());
+    if (image_idx != -1) {
+        special_image_id = gguf_get_val_u32(ctx, image_idx);
+    }
+    const int audio_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_AUDIO_ID).c_str());
+    if (audio_idx != -1) {
+        special_audio_id = gguf_get_val_u32(ctx, audio_idx);
+    }
     const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
     if (token_idx == -1) {
         throw std::runtime_error("cannot find tokenizer vocab in model file\n");
@@ -1730,6 +1740,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         token_data.score = scores ? scores[i] : 0.0f;
         token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;

+
         if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
             switch(toktypes[i]) {
                 case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break;
@@ -1790,6 +1801,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
             { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
             { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
+            { LLM_KV_TOKENIZER_IMAGE_ID,   special_image_id   },
+            { LLM_KV_TOKENIZER_AUDIO_ID,   special_audio_id   },

             // deprecated
             { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
@@ -1867,6 +1880,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     }
                 }
             }
+            if (special_image_id == LLAMA_TOKEN_NULL) {
+                if (t.first == "<|IMAGE|>" || t.first == "<IMAGE>") {
+                    special_image_id = t.second;
+                }
+            }
+            if (special_audio_id == LLAMA_TOKEN_NULL) {
+                if (t.first == "<|AUDIO|>" || t.first == "<AUDIO>") {
+                    special_audio_id = t.second;
+                }
+            }

             // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
             if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
@@ -3003,6 +3026,14 @@ llama_token llama_vocab::token_fim_sep() const {
     return pimpl->special_fim_sep_id;
 }

+llama_token llama_vocab::token_image() const {
+    return pimpl->special_image_id;
+}
+
+llama_token llama_vocab::token_audio() const {
+    return pimpl->special_audio_id;
+}
+
 bool llama_vocab::get_add_space_prefix() const {
     return pimpl->add_space_prefix;
 }
@@ -3243,6 +3274,14 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
     return vocab->token_fim_sep();
 }

+llama_token llama_vocab_image_token(const struct llama_vocab * vocab) {
+    return vocab->token_image();
+}
+
+llama_token llama_vocab_audio_token(const struct llama_vocab * vocab) {
+    return vocab->token_audio();
+}
+
 // deprecated
 const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
     return llama_vocab_get_text(vocab, token);
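
On the conversion side, a GGUF producer can set the new keys directly so the name-based <|IMAGE|>/<|AUDIO|> fallback above is never needed. A hypothetical sketch using the C GGUF API, where `gctx` is a gguf_context being prepared for writing and the token ids are purely illustrative:

    gguf_set_val_u32(gctx, "tokenizer.ggml.image_token_id", 151655);
    gguf_set_val_u32(gctx, "tokenizer.ggml.audio_token_id", 151646);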

src/llama-vocab.h

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ struct llama_vocab {
     llama_token token_fim_rep() const;
     llama_token token_fim_sep() const;

+    llama_token token_image() const;
+    llama_token token_audio() const;
     bool get_add_space_prefix () const;
     bool get_add_bos          () const;
     bool get_add_eos          () const;

tools/mtmd/clip.cpp

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ enum patch_merge_type {

 struct clip_hparams {
     int32_t image_size;
-    int32_t patch_size;
+    int32_t patch_size = INT_MAX;
     int32_t n_embd;
     int32_t n_ff;
     int32_t projection_dim;

tools/mtmd/mtmd-cli.cpp

Lines changed: 27 additions & 17 deletions
@@ -160,8 +160,9 @@ struct mtmd_cli_context {
     }
 };

-static int generate_response(mtmd_cli_context & ctx, int n_predict) {
+static std::string generate_response(mtmd_cli_context & ctx, int n_predict) {
     llama_tokens generated_tokens;
+    std::string response = "";
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
             LOG("\n");
@@ -176,8 +177,9 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
             LOG("\n");
             break; // end of generation
         }
-
-        LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        std::string piece = common_token_to_piece(ctx.lctx, token_id);
+        LOG("%s", piece.c_str());
+        response += piece;
         fflush(stdout);

         if (g_is_interrupted) {
@@ -190,17 +192,18 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
         if (llama_decode(ctx.lctx, ctx.batch)) {
             LOG_ERR("failed to decode token\n");
-            return 1;
+            return "";
         }
     }
-    return 0;
+    return response;
 }

-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
+static int eval_message(mtmd_cli_context & ctx, const std::vector<common_chat_msg> & messages, bool add_bos = false) {
     common_chat_templates_inputs tmpl_inputs;
-    tmpl_inputs.messages = {msg};
+    tmpl_inputs.messages = messages;
     tmpl_inputs.add_generation_prompt = true;
-    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    tmpl_inputs.no_part_concat = true;
+    tmpl_inputs.use_jinja = true;
     auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());

@@ -303,10 +306,10 @@ int main(int argc, char ** argv) {
             return 1; // error is already printed by libmtmd
         }
     }
-    if (eval_message(ctx, msg, true)) {
+    if (eval_message(ctx, {msg}, true)) {
         return 1;
     }
-    if (!g_is_interrupted && generate_response(ctx, n_predict)) {
+    if (!g_is_interrupted && generate_response(ctx, n_predict).empty()) {
         return 1;
     }

@@ -324,7 +327,7 @@ int main(int argc, char ** argv) {

     bool is_first_msg = true;
     std::string content;
-
+    std::vector<common_chat_msg> messages;
     while (!g_is_interrupted) {
         g_is_generating = false;
         LOG("\n> ");
@@ -357,24 +360,31 @@ int main(int argc, char ** argv) {
             std::string media_path = line.substr(7);
             if (ctx.load_media(media_path)) {
                 LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
-                content += mtmd_default_marker();
+                //content += mtmd_default_marker();
+                common_chat_msg msg;
+                msg.role = "user";
+                msg.content_parts.push_back({"image", ""});
+                messages.push_back(std::move(msg));
             }
             // else, error is already printed by libmtmd
             continue;
-        } else {
-            content += line;
         }
         common_chat_msg msg;
         msg.role = "user";
-        msg.content = content;
-        int ret = eval_message(ctx, msg, is_first_msg);
+        msg.content = line;
+        messages.push_back(std::move(msg));
+        int ret = eval_message(ctx, messages, is_first_msg);
         if (ret) {
             return 1;
         }
         if (g_is_interrupted) break;
-        if (generate_response(ctx, n_predict)) {
+        auto response = generate_response(ctx, n_predict);
+        if (response.empty()) {
             return 1;
         }
+        common_chat_msg response_message;
+        response_message.role = "assistant"; // the model's reply joins the history as the assistant turn
+        response_message.content = response;
+        messages.push_back(response_message);
         content.clear();
         is_first_msg = false;
     }
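
With history and typed content parts in place, the jinja path renders the media placeholder through the model's own image token; for a Qwen-style template the resulting prompt should look roughly along these lines (illustrative, abbreviated):

    <|im_start|>user
    <|vision_bos|><|IMAGE|><|vision_eos|>Describe the image.<|im_end|>
    <|im_start|>assistant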

tools/mtmd/mtmd.cpp

Lines changed: 62 additions & 4 deletions
@@ -254,8 +254,10 @@ struct mtmd_context {

     } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
         // <|vision_start|> ... (image embeddings) ... <|vision_end|>
-        img_beg = "<|vision_start|>";
-        img_end = "<|vision_end|>";
+        // Intentionally left unset: QwenVL and Qwen Omni have their own image tokens,
+        // and the jinja template already produces something like <|vision_bos|><|IMAGE|><|vision_eos|>.
+        // img_beg = "<|vision_start|>";
+        // img_end = "<|vision_end|>";

     } else if (proj == PROJECTOR_TYPE_LLAMA4) {
         // (more details in mtmd_context constructor)
@@ -385,10 +387,21 @@ struct mtmd_tokenizer {

     int32_t tokenize(mtmd_input_chunks * output) {
         cur.entries.clear();
-        std::vector<std::string> parts = split_text(input_text, ctx->media_marker);
+
         size_t i_bm = 0; // index of the current bitmap
+        llama_token imageTokenID = llama_vocab_image_token(vocab);
+        std::string imageToken;
+        std::vector<std::string> delimiters;
+        delimiters.push_back(ctx->media_marker);
+
+        if (imageTokenID != LLAMA_TOKEN_NULL) {
+            imageToken = llama_vocab_get_text(vocab, imageTokenID);
+            delimiters.push_back(imageToken);
+        }
+
+        std::vector<std::string> parts = split_text_multi(input_text, delimiters);
         for (auto & part : parts) {
-            if (part == ctx->media_marker) {
+            if (part == ctx->media_marker || part == imageToken) {
                 // this is a marker, we should add the next bitmap
                 if (i_bm >= bitmaps.size()) {
                     LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
@@ -707,6 +720,51 @@ struct mtmd_tokenizer {
         return result;
     }

+    static std::vector<std::string> split_text_multi(const std::string & input,
+                                                     const std::vector<std::string> & delimiters) {
+        std::vector<std::string> result;
+        if (input.empty()) {
+            return result;
+        }
+
+        size_t pos = 0;
+        while (pos < input.length()) {
+            // find the earliest occurring delimiter
+            size_t best_match_pos = std::string::npos;
+            std::string best_delimiter;
+
+            for (const auto & delimiter : delimiters) {
+                size_t match_pos = input.find(delimiter, pos);
+                if (match_pos != std::string::npos &&
+                    (best_match_pos == std::string::npos || match_pos < best_match_pos)) {
+                    best_match_pos = match_pos;
+                    best_delimiter = delimiter;
+                }
+            }
+
+            if (best_match_pos == std::string::npos) {
+                // no more delimiters found, add remaining text
+                if (pos < input.length()) {
+                    result.push_back(input.substr(pos));
+                }
+                break;
+            }
+
+            // add text before the delimiter (if any)
+            if (best_match_pos > pos) {
+                result.push_back(input.substr(pos, best_match_pos - pos));
+            }
+
+            // add the delimiter itself
+            result.push_back(best_delimiter);
+
+            // move past the delimiter
+            pos = best_match_pos + best_delimiter.length();
+        }
+
+        return result;
+    }
+
     // copied from common_tokenize
     static std::vector<llama_token> mtmd_tokenize_text_internal(
         const struct llama_vocab * vocab,
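
As a quick illustration of the splitter's contract (shown as a standalone call for clarity; in the tree it is a static member of mtmd_tokenizer, and "<__media__>" is assumed to be the default media marker):

    std::vector<std::string> parts = split_text_multi(
        "describe <|IMAGE|> and <__media__> please",
        {"<__media__>", "<|IMAGE|>"});
    // parts == {"describe ", "<|IMAGE|>", " and ", "<__media__>", " please"}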
