8 changes: 4 additions & 4 deletions examples/llava/clip.h
@@ -78,10 +78,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
 CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
 CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
 
-CLIP_API struct clip_image_size * clip_image_size_init();
-CLIP_API struct clip_image_u8 * clip_image_u8_init ();
-CLIP_API struct clip_image_f32 * clip_image_f32_init();
-CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
+CLIP_API struct clip_image_size * clip_image_size_init(void);
+CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
+CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
+CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
 
 // nx, ny are the output image dimensions
 CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
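Note on the clip.h change above: it only matters when the header is compiled as C. In C (before C23), an empty parameter list `()` declares a function with unspecified parameters, so a call such as `clip_image_size_init(42)` would compile without a diagnostic; `(void)` makes it a true prototype that rejects any arguments. In C++ the two spellings are equivalent. A minimal sketch of the pattern, using hypothetical names rather than the clip API:

```c
/* Sketch of a C-compatible header; my_image and my_image_init are hypothetical names. */
#ifdef __cplusplus
extern "C" {
#endif

struct my_image;                          /* opaque handle */
struct my_image * my_image_init(void);    /* real prototype: my_image_init(42) is rejected by a C compiler */
/* struct my_image * my_image_init(); */  /* pre-C23 C: unspecified parameters, bad calls slip through */

#ifdef __cplusplus
}
#endif
```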
47 changes: 30 additions & 17 deletions examples/llava/mtmd-cli.cpp
@@ -63,7 +63,7 @@ static void sigint_handler(int signo) {
 #endif
 
 struct mtmd_cli_context {
-    mtmd_context_ptr ctx_vision;
+    mtmd::context_ptr ctx_vision;
     common_init_result llama_init;
 
     llama_model * model;
@@ -112,12 +112,12 @@ struct mtmd_cli_context {
 
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
-            /* use_gpu */   params.mmproj_use_gpu,
-            /* timings */   true,
-            /* n_threads */ params.cpuparams.n_threads,
-            /* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO,
-        }));
+        mtmd_context_params mparams = mtmd_context_params_default();
+        mparams.use_gpu = params.mmproj_use_gpu;
+        mparams.print_timings = true;
+        mparams.n_threads = params.cpuparams.n_threads;
+        mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
             exit(1);
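For reference, the pattern introduced above (start from `mtmd_context_params_default()` and override individual fields) keeps callers source-compatible if new fields are added to `mtmd_context_params`. A minimal sketch of the same pattern outside the CLI, assuming only the names visible in this diff; the header path, `mmproj_path`, `model`, and the chosen values are placeholders:

```cpp
#include "mtmd.h"  // assumed: declares mtmd_context_params, mtmd_context_params_default(), mtmd_init_from_file() and the mtmd::* wrappers

// Build a vision context from library defaults instead of positional brace-initialization.
static mtmd::context_ptr make_vision_context(const char * mmproj_path, llama_model * model) {
    mtmd_context_params mparams = mtmd_context_params_default(); // start from the library's defaults
    mparams.use_gpu       = true;                 // offload the projector when possible
    mparams.print_timings = true;                 // report encode timings
    mparams.n_threads     = 4;                    // threads for image preprocessing
    mparams.verbosity     = GGML_LOG_LEVEL_INFO;  // use GGML_LOG_LEVEL_DEBUG for more detail

    mtmd::context_ptr ctx(mtmd_init_from_file(mmproj_path, model, mparams));
    // a null pointer here means the projector file could not be loaded
    return ctx;
}
```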
@@ -173,7 +173,7 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
 }
 
 static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
-    std::vector<mtmd_bitmap> bitmaps;
+    mtmd::bitmaps bitmaps;
 
     common_chat_templates_inputs tmpl_inputs;
     tmpl_inputs.messages = {msg};
@@ -183,34 +183,47 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
     for (auto & fname : images_fname) {
-        mtmd_bitmap bitmap;
-        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+        mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
+        if (!bmp.ptr) {
             LOG_ERR("Unable to load image %s\n", fname.c_str());
             return 2; // image not found
         }
-        bitmaps.push_back(std::move(bitmap));
+        bitmaps.entries.push_back(std::move(bmp));
     }
 
     mtmd_input_text text;
-    text.text = formatted_chat.prompt;
+    text.text = formatted_chat.prompt.c_str();
     text.add_special = add_bos;
     text.parse_special = true;
-    mtmd_input_chunks chunks;
 
     if (g_is_interrupted) return 0;
 
-    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    mtmd::input_chunks chunks;
+    auto bitmaps_c_ptr = bitmaps.c_ptr();
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(),
+                                chunks.ptr.get(), // output
+                                &text, // text
+                                bitmaps_c_ptr.data(),
+                                bitmaps_c_ptr.size());
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
-    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
+    llama_pos new_n_past;
+    if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
+                                ctx.lctx, // lctx
+                                chunks.ptr.get(), // chunks
+                                ctx.n_past, // n_past
+                                0, // seq_id
+                                ctx.n_batch, // n_batch
+                                true, // logits_last
+                                &new_n_past)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_pos(chunks);
+    ctx.n_past = new_n_past;
 
     return 0;
 }
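Taken together, eval_message() now runs entirely through the C entry points (`mtmd_tokenize`, `mtmd_helper_eval_chunks`), with the `mtmd::*` RAII wrappers owning the bitmaps and chunks. A condensed, hedged sketch of that flow outside the CLI, using only the types and calls shown in this diff; the function name, prompt, image path and contexts are placeholders:

```cpp
#include "mtmd.h"  // assumed: mtmd C API plus the mtmd::bitmap/bitmaps/input_chunks wrappers used above

// Sketch of the tokenize -> eval flow from eval_message(), with the CLI plumbing removed.
static int eval_prompt_with_image(mtmd_context * ctx_vision, llama_context * lctx,
                                  const char * prompt, const char * image_path,
                                  llama_pos & n_past, int n_batch) {
    // Load the image; the helper returns a raw mtmd_bitmap*, which the wrapper then owns.
    mtmd::bitmaps bitmaps;
    mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(image_path));
    if (!bmp.ptr) {
        return 2; // image not found / unreadable
    }
    bitmaps.entries.push_back(std::move(bmp));

    mtmd_input_text text;
    text.text          = prompt;  // must stay valid for the duration of mtmd_tokenize()
    text.add_special   = true;    // add BOS-style special tokens
    text.parse_special = true;    // parse special tokens (e.g. image markers) in the prompt

    // Tokenize text + images into a chunk list owned by the wrapper.
    mtmd::input_chunks chunks;
    auto bitmaps_c_ptr = bitmaps.c_ptr();
    int32_t res = mtmd_tokenize(ctx_vision,
                                chunks.ptr.get(),      // output chunks
                                &text,                 // input text
                                bitmaps_c_ptr.data(),  // input bitmaps
                                bitmaps_c_ptr.size());
    if (res != 0) {
        return 1; // tokenization failed
    }

    // Evaluate every chunk (text and image) and receive the updated position back.
    llama_pos new_n_past;
    if (mtmd_helper_eval_chunks(ctx_vision, lctx, chunks.ptr.get(),
                                n_past, /*seq_id*/ 0, n_batch,
                                /*logits_last*/ true, &new_n_past)) {
        return 1; // evaluation failed
    }
    n_past = new_n_past;
    return 0;
}
```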
@@ -241,7 +254,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
     int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
 
-    // ctrl+C handling
+    // Ctrl+C handling
     {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;