Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/llava/gemma3-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct gemma3_context {
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
/* use_gpu */ true,
/* timings */ true,
/* hash */ false,
/* n_threads */ params.cpuparams.n_threads,
/* verbosity */ GGML_LOG_LEVEL_INFO,
}));
Expand Down
68 changes: 62 additions & 6 deletions examples/llava/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@ struct mtmd_context {
struct clip_ctx * ctx_clip;
const struct llama_model * text_model;
std::vector<float> image_embd_v; // image embedding vector

bool print_timings;
int n_threads;
std::string image_marker;
bool calc_image_hash;

// TODO @ngxson : add timings

mtmd_context(const char * mmproj_fname,
const llama_model * text_model,
const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
const mtmd_context_params & ctx_params) :
print_timings (ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker),
calc_image_hash(ctx_params.calc_image_hash)
{
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
Expand All @@ -49,6 +56,7 @@ struct mtmd_image_tokens {
uint32_t ny; // number of tokens in y direction
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
size_t image_hash = 0; // hash of the image, useful for KV cache tracking
};

mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
Expand Down Expand Up @@ -88,6 +96,16 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
return result;
}

// Fold a sequence of floats into a single 64-bit hash.
// The accumulator is seeded with the element count, then each value is
// mixed in with the boost::hash_combine constant so that both the
// contents and the length of the vector affect the result.
static uint64_t hash_vector_float(const std::vector<float> & vec) {
    std::hash<float> hash_fn;
    uint64_t acc = vec.size();
    for (size_t i = 0; i < vec.size(); i++) {
        acc ^= hash_fn(vec[i]) + 0x9e3779b9 + (acc << 6) + (acc >> 2);
    }
    return acc;
}

mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps) {
Expand Down Expand Up @@ -153,6 +171,11 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
image_tokens->ny = 1; // TODO
image_tokens->batch_f32 = std::move(batch_f32);

// optionally calculate the hash
if (ctx->calc_image_hash) {
image_tokens->image_hash = hash_vector_float(image_tokens->batch_f32.entries[0]->buf);
}

mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
Expand All @@ -166,15 +189,40 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
return output;
}

void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
for (auto & chunk : *chunks) {
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
delete chunk.tokens_image;
// Destroy an image-token group produced by mtmd_tokenize().
// Safe to call with nullptr.
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
    // no null check needed: `delete` on a null pointer is a well-defined no-op
    delete image_tokens;
}

// Free a chunk list returned by mtmd_tokenize().
// If free_images is true, the image-token groups owned by image chunks are
// destroyed as well; otherwise the caller keeps ownership of them and must
// release each one later with mtmd_image_tokens_free().
void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) {
    if (free_images) {
        for (auto & chunk : *chunks) {
            // skip text chunks and image chunks that carry no token data
            if (chunk.type != MTMD_INPUT_CHUNK_TYPE_IMAGE || !chunk.tokens_image) {
                continue;
            }
            mtmd_image_tokens_free(chunk.tokens_image);
            chunk.tokens_image = nullptr;
        }
    }
    delete chunks;
}

// total number of embedding tokens this image expands to (nx * ny)
size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
    return image_tokens->n_tokens();
}

// number of tokens in the x direction of the image token grid
size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
    return image_tokens->nx;
}

// number of tokens in the y direction of the image token grid
size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
    return image_tokens->ny;
}

// hash of the preprocessed image data, useful for KV cache tracking;
// returns 0 unless mtmd_context_params::calc_image_hash was enabled
uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens) {
    return image_tokens->image_hash;
}

int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
Expand Down Expand Up @@ -289,7 +337,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
}

int32_t n_tokens = chunk.tokens_image->n_tokens();
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image);
float * embd = mtmd_get_output_embd(ctx);
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
int64_t t1 = ggml_time_ms();
Expand Down Expand Up @@ -339,3 +387,11 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
return 0;
}

// Whether llama_decode must be run with a non-causal attention mask for
// image embeddings; currently only the Gemma 3 projector requires this.
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
    const projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
    return proj_type == PROJECTOR_TYPE_GEMMA3;
}
27 changes: 24 additions & 3 deletions examples/llava/mtmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
struct mtmd_context_params {
bool use_gpu = true;
bool print_timings = true;
// calc_image_hash is useful for tracking KV cache
// if not set, mtmd_image_tokens_get_hash will return 0
bool calc_image_hash = false;
int n_threads = 4;
enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
const char * image_marker = "<__image__>";
Expand Down Expand Up @@ -81,13 +84,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
// 2. (image tokens)
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// the returned value must be freed using mtmd_input_chunks_free()
// this function is thread-safe (shared ctx)
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps);

// free image chunk data
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free()
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images);

// access mtmd_image_tokens
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
MTMD_API uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens);
MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);

// returns 0 on success
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
Expand All @@ -96,6 +107,11 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);



//
// helper functions (can be implemented based on other functions)
//
Expand Down Expand Up @@ -133,10 +149,15 @@ struct mtmd_context_deleter {
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_input_chunks_deleter {
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); }
};
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

// custom deleter so image-token groups can be managed by std::unique_ptr
struct mtmd_image_tokens_deleter {
    void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
};
// RAII owner for mtmd_image_tokens returned by the tokenizer
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;

#else

static_assert(false && "C header is not yet supported by this library");
Expand Down
Loading