Commit 96bf95e
migrated gemma3 to llava2
1 parent 235340d

File tree: 5 files changed, +257 -117 lines

examples/llava/CMakeLists.txt

Lines changed: 31 additions & 1 deletion
@@ -1,3 +1,5 @@
+# llava (legacy)
+
 add_library(llava OBJECT
             llava.cpp
             llava.h
@@ -22,12 +24,40 @@ if (BUILD_SHARED_LIBS)
     install(TARGETS llava_shared LIBRARY)
 endif()

+# llava2
+
+add_library(llava2 OBJECT
+            llava2.cpp
+            llava2.h
+            clip.cpp
+            clip.h
+            clip-impl.h
+            )
+
+target_link_libraries(llava2 PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+
+target_include_directories(llava2 PUBLIC .)
+target_include_directories(llava2 PUBLIC ../..)
+
+target_compile_features(llava2 PRIVATE cxx_std_17)
+
+add_library(llava2_static STATIC $<TARGET_OBJECTS:llava2>)
+if (BUILD_SHARED_LIBS)
+    set_target_properties(llava2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(llava2 PRIVATE LLAMA_SHARED LLAMA_BUILD)
+    add_library(llava2_shared SHARED $<TARGET_OBJECTS:llava2>)
+    target_link_libraries(llava2_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+    install(TARGETS llava2_shared LIBRARY)
+endif()
+
 if (NOT MSVC)
     target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
+    target_compile_options(llava2 PRIVATE -Wno-cast-qual) # stb_image.h
 endif()

 if(TARGET BUILD_INFO)
     add_dependencies(llava BUILD_INFO)
+    add_dependencies(llava2 BUILD_INFO)
 endif()

 set(TARGET llama-llava-cli)
@@ -55,7 +85,7 @@ set(TARGET llama-gemma3-cli)
 add_executable(${TARGET} gemma3-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common llava2 ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)

 set(TARGET llama-llava-clip-quantize-cli)

examples/llava/clip.cpp

Lines changed: 2 additions & 0 deletions
@@ -2330,6 +2330,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img)
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
         n_patches = x_patch * y_patch;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        n_patches = 256;
    }

    return n_patches;
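A note on the hard-coded value: Gemma 3's vision tower is commonly described as a SigLIP encoder that resizes every input to 896×896, uses 14-pixel patches, and average-pools the resulting 64×64 patch grid by 4× per side, which lands exactly on 256 tokens per image. That is an assumption about the model rather than something this diff establishes; as a compile-time sketch:

// Sketch of where the fixed n_patches = 256 plausibly comes from.
// Assumes Gemma 3's SigLIP tower: 896x896 input, patch size 14, 4x4 pooling.
// These constants are illustrative, not read from the model file.
constexpr int image_size  = 896;
constexpr int patch_size  = 14;
constexpr int pool_factor = 4;

constexpr int patches_per_side = image_size / patch_size;        // 64
constexpr int pooled_per_side  = patches_per_side / pool_factor; // 16
static_assert(pooled_per_side * pooled_per_side == 256,
              "matches the value hard-coded in clip_n_patches_by_img");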

examples/llava/gemma3-cli.cpp

Lines changed: 76 additions & 104 deletions
@@ -5,6 +5,7 @@
 #include "llama.h"
 #include "ggml.h"
 #include "console.h"
+#include "chat.h"
 #include "llava2.h"

 #include <vector>
@@ -56,13 +57,18 @@ static void sigint_handler(int signo) {
 #endif

 struct gemma3_context {
-    llava2_context_ptr ctx_llava2;
+    llava2_context_ptr ctx_vision;
     common_init_result llama_init;

     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
     llama_batch batch;
+    int n_batch;
+
+    // note: the gemma3 template is "linear", meaning each turn is completely separate from the others,
+    // so we don't need to keep track of chat history
+    common_chat_templates_ptr tmpls;

     int n_threads = 1;
     llama_pos n_past = 0;
@@ -73,18 +79,20 @@ struct gemma3_context {
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch = llama_batch_init(params.n_batch, 0, 1);
-        init_clip_model(params);
+        n_batch = params.n_batch;
+        tmpls = common_chat_templates_init(model, params.chat_template);
+        init_vision_context(params);
     }

-    void init_clip_model(common_params & params) {
+    void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_llava2 = llava2_init_from_file(clip_path, model, llava2_context_params{
+        ctx_vision = llava2_init_from_file(clip_path, model, llava2_context_params{
            /* use_gpu */   true,
            /* n_threads */ params.cpuparams.n_threads,
            /* verbosity */ GGML_LOG_LEVEL_INFO,
        });
-        if (!ctx_llava2.get()) {
-            LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
+        if (!ctx_vision.get()) {
+            LOG_ERR("Failed to load vision model from %s\n", clip_path);
            exit(1);
        }
    }
@@ -123,77 +131,6 @@ struct decode_embd_batch {
     }
 };

-static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
-    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
-    for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
-    }
-    if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
-    }
-    // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
-        LOG_ERR("Failed to decode text\n");
-        return 1;
-    }
-    return 0;
-}
-
-static int eval_image(gemma3_context & ctx, std::string & fname) {
-    std::vector<float> image_embd_v;
-    int n_embd = llama_model_n_embd(ctx.model);
-    int n_tokens = 256;
-    image_embd_v.resize(n_tokens * n_embd);
-
-    bool ok;
-    struct clip_image_u8 * img_u8 = clip_image_u8_init();
-    ok = clip_image_load_from_file(fname.c_str(), img_u8);
-    if (!ok) {
-        LOG_ERR("Unable to load image %s\n", fname.c_str());
-        clip_image_u8_free(img_u8);
-        return 2; // non-fatal error
-    }
-
-    clip_image_f32_batch batch_f32;
-    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
-    if (!ok) {
-        LOG_ERR("Unable to preprocess image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-
-    int64_t t0 = ggml_time_ms();
-    LOG("Encoding image %s\n", fname.c_str());
-    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
-    if (!ok) {
-        LOG_ERR("Unable to encode image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-
-    clip_image_f32_batch_free(&batch_f32);
-    clip_image_u8_free(img_u8);
-
-    // decode image embeddings
-    int64_t t1 = ggml_time_ms();
-    eval_text(ctx, "<start_of_image>");
-    llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
-        LOG_ERR("failed to decode image\n");
-        return 1;
-    }
-    ctx.n_past += n_tokens;
-    llama_set_causal_attn(ctx.lctx, true);
-    eval_text(ctx, "<end_of_image>");
-    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
-    return 0;
-}
-
 static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
@@ -223,6 +160,41 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
     return 0;
 }

+static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+    std::vector<llava2_bitmap> bitmaps;
+
+    common_chat_templates_inputs tmpl_inputs;
+    tmpl_inputs.messages = {msg};
+    tmpl_inputs.add_generation_prompt = true;
+    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+
+    for (auto & fname : images_fname) {
+        llava2_bitmap bitmap;
+        if (llava2_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            LOG_ERR("Unable to load image %s\n", fname.c_str());
+            return 2; // image not found
+        }
+        bitmaps.push_back(std::move(bitmap));
+    }
+
+    std::vector<llava2_input_chunk> chunks;
+    if (llava2_tokenize(ctx.ctx_vision, chunks, formatted_chat.prompt, add_bos, true, bitmaps)) {
+        LOG_ERR("Unable to tokenize prompt\n");
+        return 1;
+    }
+
+    if (llava2_helper_eval(ctx.ctx_vision, ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
+        LOG_ERR("Unable to eval prompt\n");
+        return 1;
+    }
+
+    ctx.n_past += llava2_helper_get_n_tokens(chunks);
+
+    return 0;
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();

@@ -264,22 +236,15 @@ int main(int argc, char ** argv) {
 #endif
     }

-    if (eval_text(ctx, "<bos>")) {
-        return 1;
-    }
-
     if (is_single_turn) {
         g_is_generating = true;
-        std::string prompt = "<start_of_turn>user\n<image>" + params.prompt + "<end_of_turn><start_of_turn>model\n";
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
-        for (auto & fname : params.image) {
-            if (eval_image(ctx, fname)) {
-                return 1;
-            }
+        if (params.prompt.find("<__image__>") == std::string::npos) {
+            params.prompt += " <__image__>";
         }
-        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+        common_chat_msg msg;
+        msg.role = "user";
+        msg.content = params.prompt;
+        if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
         if (generate_response(ctx, smpl, n_predict)) {
@@ -293,9 +258,9 @@ int main(int argc, char ** argv) {
         LOG("\n /quit or /exit exit the program");
         LOG("\n");

-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
+        bool is_first_msg = true;
+        std::vector<std::string> images_fname;
+        std::string content;

         while (true) {
             g_is_generating = false;
@@ -320,24 +285,31 @@ int main(int argc, char ** argv) {
             g_is_generating = true;
             if (line.find("/image") == 0) {
                 std::string image = line.substr(7);
-                int res = eval_image(ctx, image);
-                if (res == 2) {
-                    continue; // image not found
-                }
-                if (res) {
-                    return 1;
-                }
+                images_fname.push_back(string_strip(image));
+                content += "<__image__>";
                 continue;
+            } else {
+                content += line;
             }
-            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
-                return 1;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = content;
+            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (ret == 2) {
+                // non-fatal error
+                images_fname.clear();
+                content.clear();
+                continue;
             }
-            if (generate_response(ctx, smpl, n_predict)) {
+            if (ret) {
                 return 1;
             }
-            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+            if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
+            images_fname.clear();
+            content.clear();
+            is_first_msg = false;
         }
     }
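Taken together, the new code path in gemma3-cli.cpp reduces to one generic per-turn sequence: load the projector once with llava2_init_from_file, then for each turn load the bitmaps, tokenize the chat-templated prompt into mixed text/image chunks, and evaluate them. Below is a minimal sketch of that sequence using only functions introduced by this commit; the parameter-name comments (parse_special, seq_id) are inferred from the call sites in eval_message, not from documented names, and error handling is trimmed:

#include <string>
#include <vector>

#include "llama.h"
#include "llava2.h"

// One user turn: a prompt containing "<__image__>" markers plus image files.
// Mirrors eval_message() above; return codes follow the same convention
// (2 = image not found / non-fatal, 1 = fatal, 0 = ok).
static int eval_turn(llava2_context_ptr & ctx_vision, llama_context * lctx,
                     const std::string & prompt,                   // already chat-templated
                     const std::vector<std::string> & image_paths,
                     llama_pos & n_past, int n_batch, bool add_bos) {
    std::vector<llava2_bitmap> bitmaps;
    for (const auto & path : image_paths) {
        llava2_bitmap bmp;
        if (llava2_bitmap_init_from_file(path.c_str(), bmp)) {
            return 2;
        }
        bitmaps.push_back(std::move(bmp));
    }

    // Split the prompt into text/image chunks; each "<__image__>" marker
    // stands in for the embeddings of the corresponding bitmap.
    std::vector<llava2_input_chunk> chunks;
    if (llava2_tokenize(ctx_vision, chunks, prompt, add_bos, /* parse_special? */ true, bitmaps)) {
        return 1;
    }

    // Encode and decode all chunks, batching text tokens n_batch at a time.
    if (llava2_helper_eval(ctx_vision, lctx, chunks, n_past, /* seq_id? */ 0, n_batch)) {
        return 1;
    }
    n_past += llava2_helper_get_n_tokens(chunks);
    return 0;
}

Compared with the deleted eval_text/eval_image pair, the image-embedding splicing (including the <start_of_image>/<end_of_image> wrapping and the non-causal attention toggle) presumably now lives behind llava2_helper_eval, so the CLI no longer touches clip.h directly.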
