Skip to content

Commit a9ef623

Browse files
committed
correct pre/postfix
1 parent 7cc4108 commit a9ef623

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

examples/llava/clip-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
326326
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
327327
}
328328
}
329+
330+
//
331+
// API used internally with llava2
332+
//
333+
334+
projector_type clip_get_projector_type(const struct clip_ctx * ctx);

examples/llava/clip.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2884,3 +2884,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
28842884
clip_image_encode(ctx, n_threads, &clip_img, vec);
28852885
return true;
28862886
}
2887+
2888+
//
2889+
// API used internally with llava2
2890+
//
2891+
2892+
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
2893+
return ctx->proj_type;
2894+
}

examples/llava/llava2.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,17 @@ int32_t llava2_tokenize(llava2_context_ptr & ctx,
9797
const std::vector<llava2_bitmap> & bitmaps) {
9898
auto vocab = llama_model_get_vocab(ctx->text_model);
9999

100+
std::string prompt_modified(prompt);
101+
std::string marker_modified(ctx->image_marker);
102+
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
103+
// a bit hacky here, but works for now
104+
// for some models, we need to add prefix and suffix to the image embeddings
105+
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
106+
// <start_of_image> ... (image embeddings) ... <end_of_image>
107+
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
108+
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
109+
}
110+
100111
std::vector<std::string> parts = string_split_str(prompt, ctx->image_marker);
101112
output.clear();
102113
output.reserve(parts.size());

0 commit comments

Comments
 (0)