Skip to content

Commit f6b6517

Browse files
committed
wip
1 parent 4a4f35c commit f6b6517

File tree

3 files changed

+97
-47
lines changed

3 files changed

+97
-47
lines changed

examples/llava/mtmd-cli.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@ struct mtmd_cli_context {
112112

113113
void init_vision_context(common_params & params) {
114114
const char * clip_path = params.mmproj.path.c_str();
115-
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
116-
/* use_gpu */ params.mmproj_use_gpu,
117-
/* timings */ true,
118-
/* n_threads */ params.cpuparams.n_threads,
119-
/* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO,
120-
}));
115+
mtmd_context_params mparams = mtmd_context_params_default();
116+
mparams.use_gpu = params.mmproj_use_gpu;
117+
mparams.print_timings = true;
118+
mparams.n_threads = params.cpuparams.n_threads;
119+
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
120+
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
121121
if (!ctx_vision.get()) {
122122
LOG_ERR("Failed to load vision model from %s\n", clip_path);
123123
exit(1);
@@ -228,7 +228,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
228228
text.text = formatted_chat.prompt;
229229
text.add_special = add_bos;
230230
text.parse_special = true;
231-
mtmd_input_chunks chunks;
231+
std::vector<mtmd_input_chunk> chunks;
232232

233233
if (g_is_interrupted) return 0;
234234

examples/llava/mtmd.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ enum mtmd_slice_tmpl {
2121
// TODO @ngxson : add support for idefics (SmolVLM)
2222
};
2323

24+
mtmd_context_params mtmd_context_params_default() {
25+
mtmd_context_params params;
26+
params.use_gpu = true;
27+
params.print_timings = true;
28+
params.n_threads = 4;
29+
params.verbosity = GGML_LOG_LEVEL_INFO;
30+
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
31+
return params;
32+
}
33+
2434
struct mtmd_context {
2535
struct clip_ctx * ctx_clip;
2636
const struct llama_model * text_model;
@@ -411,7 +421,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
411421
return ctx->image_embd_v.data();
412422
}
413423

414-
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
424+
size_t mtmd_helper_get_n_tokens(std::vector<mtmd_input_chunk> & chunks) {
415425
size_t n_tokens = 0;
416426
for (auto & chunk : chunks) {
417427
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -462,7 +472,7 @@ struct decode_embd_batch {
462472

463473
int32_t mtmd_helper_eval(mtmd_context * ctx,
464474
llama_context * lctx,
465-
mtmd_input_chunks & chunks,
475+
std::vector<mtmd_input_chunk> & chunks,
466476
llama_pos pos0,
467477
llama_seq_id seq_id,
468478
int32_t n_batch) {

examples/llava/mtmd.h

Lines changed: 78 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
#include "llama.h"
66
#include "clip.h"
77

8+
#include <stddef.h>
9+
#include <stdint.h>
10+
#include <stdbool.h>
11+
12+
#ifdef __cplusplus
813
#include <vector>
914
#include <cinttypes>
1015
#include <memory>
16+
#endif
1117

1218
#ifdef LLAMA_SHARED
1319
# if defined(_WIN32) && !defined(__MINGW32__)
@@ -23,7 +29,7 @@
2329
# define MTMD_API
2430
#endif
2531

26-
#ifdef __cplusplus
32+
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
2733

2834
enum mtmd_input_chunk_type {
2935
MTMD_INPUT_CHUNK_TYPE_TEXT,
@@ -33,6 +39,75 @@ enum mtmd_input_chunk_type {
3339
struct mtmd_context;
3440
struct mtmd_image_tokens;
3541

42+
//
43+
// C API
44+
// this is made to closely resemble the C++ API
45+
//
46+
47+
// forward declaration for C API (the actual struct is defined in C++)
48+
struct mtmd_bitmap;
49+
struct mtmd_input_chunk;
50+
51+
struct mtmd_context_params {
52+
bool use_gpu;
53+
bool print_timings;
54+
int n_threads;
55+
enum ggml_log_level verbosity;
56+
const char * image_marker;
57+
};
58+
59+
MTMD_API mtmd_context_params mtmd_context_params_default();
60+
61+
// initialize the mtmd context
62+
// return nullptr on failure
63+
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
64+
const llama_model * text_model,
65+
const mtmd_context_params ctx_params);
66+
67+
MTMD_API void mtmd_free(mtmd_context * ctx);
68+
69+
// get output embeddings from the last encode pass
70+
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
71+
72+
// whether we need to set non-causal mask before llama_decode
73+
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
74+
75+
// mtmd_bitmap
76+
//
77+
// length of data must be nx * ny * 3
78+
// the data is in RGBRGBRGB... format
79+
// the id is optional (can be nullptr), but useful for KV cache tracking
80+
MTMD_API mtmd_bitmap * mtmd_bitmap_init(
81+
uint32_t nx,
82+
uint32_t ny,
83+
const unsigned char * data,
84+
const char * id, size_t id_len);
85+
MTMD_API uint32_t mtmd_bitmap_get_nx (mtmd_bitmap * bitmap);
86+
MTMD_API uint32_t mtmd_bitmap_get_ny (mtmd_bitmap * bitmap);
87+
MTMD_API const unsigned char * mtmd_bitmap_get_data(mtmd_bitmap * bitmap);
88+
MTMD_API const char * mtmd_bitmap_get_id (mtmd_bitmap * bitmap);
89+
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
90+
91+
// mtmd_input_chunk
92+
//
93+
// the instance can be constructed via mtmd_tokenize()
94+
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk);
95+
MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
96+
MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
97+
MTMD_API void mtmd_input_chunk_free (mtmd_input_chunk * chunk);
98+
99+
100+
//
101+
// C++ API
102+
//
103+
104+
#ifdef __cplusplus
105+
106+
struct mtmd_context_deleter {
107+
void operator()(mtmd_context * val) { mtmd_free(val); }
108+
};
109+
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
110+
36111
// represents raw image data, layout is RGBRGBRGB...
37112
// length of data must be nx * ny * 3
38113
struct mtmd_bitmap {
@@ -53,30 +128,12 @@ struct mtmd_input_chunk {
53128
mtmd_image_tokens_ptr tokens_image;
54129
};
55130

56-
using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
57-
58-
struct mtmd_context_params {
59-
bool use_gpu = true;
60-
bool print_timings = true;
61-
int n_threads = 4;
62-
enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
63-
const char * image_marker = "<__image__>";
64-
};
65-
66131
struct mtmd_input_text {
67132
std::string text;
68133
bool add_special;
69134
bool parse_special;
70135
};
71136

72-
// initialize the mtmd context
73-
// return nullptr on failure
74-
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
75-
const llama_model * text_model,
76-
const mtmd_context_params ctx_params);
77-
78-
MTMD_API void mtmd_free(mtmd_context * ctx);
79-
80137
// tokenize an input text prompt and an image
81138
// the prompt must have the input image marker (default: "<__image__>") in it
82139
// the marker will be replaced with the image tokens
@@ -108,20 +165,14 @@ MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
108165
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
109166
const mtmd_image_tokens * image_tokens);
110167

111-
// get output embeddings from the last encode pass
112-
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
113-
114-
// whether we need to set non-causal mask before llama_decode
115-
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
116-
117168

118169

119170
//
120171
// helper functions (can be implemented based on other functions)
121172
//
122173

123174
// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
124-
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
175+
MTMD_API size_t mtmd_helper_get_n_tokens(std::vector<mtmd_input_chunk> & chunks);
125176

126177
// helper function that automatically:
127178
// 1. run llama_decode() on text chunks
@@ -130,7 +181,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
130181
// otherwise, returns 0 on success
131182
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
132183
llama_context * lctx,
133-
mtmd_input_chunks & chunks,
184+
std::vector<mtmd_input_chunk> & chunks,
134185
llama_pos pos0,
135186
llama_seq_id seq_id,
136187
int32_t n_batch);
@@ -146,18 +197,7 @@ MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitm
146197
// this function is thread-safe
147198
MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output);
148199

149-
// convenient unique_ptr wrappers
150-
struct mtmd_context_deleter {
151-
void operator()(mtmd_context * val) { mtmd_free(val); }
152-
};
153-
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
154-
155200
#endif
156201

157-
//
158-
// C API
159-
//
160-
161-
162202

163203
#endif

0 commit comments

Comments
 (0)