Skip to content

Commit 430dbd8

Browse files
committed
improve api
1 parent a6625fa commit 430dbd8

File tree

4 files changed

+111
-86
lines changed

4 files changed

+111
-86
lines changed

examples/llava/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ add_library(mtmd OBJECT
3737
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
3838

3939
target_include_directories(mtmd PUBLIC .)
40-
target_include_directories(mtmd PUBLIC ../..)
41-
target_include_directories(mtmd PUBLIC ../../common) # for stb_image.h
40+
target_include_directories(mtmd PRIVATE ../..)
41+
target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h
4242

4343
target_compile_features(mtmd PRIVATE cxx_std_17)
4444

examples/llava/gemma3-cli.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ struct gemma3_context {
8686

8787
void init_vision_context(common_params & params) {
8888
const char * clip_path = params.mmproj.path.c_str();
89-
ctx_vision = mtmd_init_from_file(clip_path, model, mtmd_context_params{
89+
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
9090
/* use_gpu */ true,
9191
/* timings */ true,
9292
/* n_threads */ params.cpuparams.n_threads,
9393
/* verbosity */ GGML_LOG_LEVEL_INFO,
94-
});
94+
}));
9595
if (!ctx_vision.get()) {
9696
LOG_ERR("Failed to load vision model from %s\n", clip_path);
9797
exit(1);
@@ -180,22 +180,22 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
180180
bitmaps.push_back(std::move(bitmap));
181181
}
182182

183-
std::vector<mtmd_input_chunk> chunks;
184183
mtmd_input_text text;
185184
text.text = formatted_chat.prompt;
186185
text.add_special = add_bos;
187186
text.parse_special = true;
188-
if (mtmd_tokenize(ctx.ctx_vision, chunks, text, bitmaps)) {
187+
mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
188+
if (chunks == nullptr) {
189189
LOG_ERR("Unable to tokenize prompt\n");
190190
return 1;
191191
}
192192

193-
if (mtmd_helper_eval(ctx.ctx_vision, ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
193+
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
194194
LOG_ERR("Unable to eval prompt\n");
195195
return 1;
196196
}
197197

198-
ctx.n_past += mtmd_helper_get_n_tokens(chunks);
198+
ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
199199

200200
return 0;
201201
}

examples/llava/mtmd.cpp

Lines changed: 64 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,30 @@ struct mtmd_image_tokens_data {
4444
clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
4545
};
4646

47-
mtmd_context_ptr mtmd_init_from_file(const char * mmproj_fname,
47+
struct mtmd_image_tokens {
48+
uint32_t nx; // number of tokens in x direction
49+
uint32_t ny; // number of tokens in y direction
50+
uint32_t n_tokens() const { return nx * ny; }
51+
clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
52+
};
53+
54+
// Load the multimodal projector (mmproj) model from file and create a
// context bound to the given text model.
//
// Returns an owning raw pointer (release it with mtmd_free()), or
// nullptr on failure; the error is logged before returning.
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
        const struct llama_model * text_model,
        const struct mtmd_context_params ctx_params) {
    try {
        auto * ctx = new mtmd_context(mmproj_fname, text_model, ctx_params);
        return ctx;
    } catch (const std::exception & e) {
        // constructor failures (bad file, unsupported model, ...) surface here
        LOG_ERR("%s: error: %s\n", __func__, e.what());
        return nullptr;
    }
}
5864

65+
// Destroy a context created by mtmd_init_from_file().
// Accepts nullptr as a no-op, matching the usual C-style free() contract.
void mtmd_free(mtmd_context * ctx) {
    // `delete` on a null pointer is well-defined (no-op), so no guard is needed
    delete ctx;
}
70+
5971
int32_t mtmd_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
6072
clip_image_u8_ptr img_u8(clip_image_u8_init());
6173
bool ok = clip_image_load_from_file(fname, img_u8.get());
@@ -89,10 +101,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
89101
return result;
90102
}
91103

92-
int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
93-
std::vector<mtmd_input_chunk> & output,
94-
const mtmd_input_text & text,
95-
const std::vector<mtmd_bitmap> & bitmaps) {
104+
mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
105+
const mtmd_input_text & text,
106+
const std::vector<mtmd_bitmap> & bitmaps) {
107+
mtmd_input_chunks * output = new mtmd_input_chunks;
96108
auto vocab = llama_model_get_vocab(ctx->text_model);
97109

98110
std::string prompt_modified(text.text);
@@ -107,8 +119,8 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
107119
}
108120

109121
std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
110-
output.clear();
111-
output.reserve(parts.size());
122+
output->clear();
123+
output->reserve(parts.size());
112124

113125
size_t i_img = 0;
114126

@@ -119,18 +131,19 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
119131
if (tokens.empty()) {
120132
continue;
121133
}
122-
output.push_back({
123-
LLAVA2_INPUT_CHUNK_TYPE_TEXT,
134+
mtmd_input_chunk chunk{
135+
MTMD_INPUT_CHUNK_TYPE_TEXT,
124136
std::move(tokens),
125137
{},
126-
});
138+
};
139+
output->emplace_back(std::move(chunk));
127140

128141
if (&parts.back() != &part) {
129142
// add image token to middle of 2 parts
130143

131144
if (i_img >= bitmaps.size()) {
132145
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
133-
return 2;
146+
return nullptr;
134147
}
135148

136149
// shim layer
@@ -145,54 +158,58 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
145158
bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
146159
if (!ok) {
147160
LOG_ERR("Unable to preprocess image\n");
148-
return 1;
161+
return nullptr;
149162
}
150163

151-
mtmd_image_tokens image_tokens;
152-
image_tokens.nx = 0; // TODO
153-
image_tokens.ny = 0; // TODO
154-
image_tokens.n_tokens = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
155-
image_tokens.data = std::unique_ptr<mtmd_image_tokens_data>(
156-
new mtmd_image_tokens_data{
157-
std::move(batch_f32),
158-
}
159-
);
160-
161-
output.push_back({
162-
LLAVA2_INPUT_CHUNK_TYPE_IMAGE,
164+
mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
165+
image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
166+
image_tokens->ny = 1; // TODO
167+
image_tokens->batch_f32 = std::move(batch_f32);
168+
169+
mtmd_input_chunk chunk{
170+
MTMD_INPUT_CHUNK_TYPE_IMAGE,
163171
{},
164-
std::move(image_tokens),
165-
});
172+
image_tokens,
173+
};
174+
output->emplace_back(std::move(chunk));
166175
i_img++;
167176
}
168177
}
169178

170-
return 0;
179+
return output;
180+
}
181+
182+
// Free a chunk list returned by mtmd_tokenize(), including the heap-allocated
// mtmd_image_tokens owned by each image chunk.
// Accepts nullptr as a no-op (consistent with mtmd_free()), since
// mtmd_tokenize() returns nullptr on failure and callers may pass that
// straight through.
void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
    if (!chunks) {
        return; // tolerate nullptr instead of dereferencing it below
    }
    for (auto & chunk : *chunks) {
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
            delete chunk.tokens_image;
        }
    }
    delete chunks;
}
172190

173-
LLAVA2_API int32_t mtmd_encode(mtmd_context_ptr & ctx,
174-
const mtmd_image_tokens & image_tokens) {
191+
// Encode the preprocessed image patches of `image_tokens` with the CLIP
// model. The resulting embeddings are stored inside the context and can
// be retrieved afterwards via mtmd_get_output_embd().
// Returns 0 on success, 1 on failure.
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
    const int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
    // one embedding vector of length n_mmproj_embd per image token
    ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
    const bool success = clip_image_batch_encode(
        ctx->ctx_clip,
        ctx->n_threads,
        image_tokens->batch_f32.get(),
        ctx->image_embd_v.data());
    return success ? 0 : 1;
}
184201

185-
LLAVA2_API float * mtmd_get_output_embd(mtmd_context_ptr & ctx) {
202+
float * mtmd_get_output_embd(mtmd_context * ctx) {
186203
return ctx->image_embd_v.data();
187204
}
188205

189-
size_t mtmd_helper_get_n_tokens(std::vector<mtmd_input_chunk> & chunks) {
206+
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
190207
size_t n_tokens = 0;
191-
for (auto & chunk : chunks) {
192-
if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_TEXT) {
208+
for (auto & chunk : *chunks) {
209+
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
193210
n_tokens += chunk.tokens_text.size();
194-
} else if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_IMAGE) {
195-
n_tokens += chunk.tokens_image.n_tokens;
211+
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
212+
n_tokens += chunk.tokens_image->n_tokens();
196213
} else {
197214
GGML_ASSERT(false && "chunk type not supported");
198215
}
@@ -235,19 +252,19 @@ struct decode_embd_batch {
235252
}
236253
};
237254

238-
int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
255+
int32_t mtmd_helper_eval(mtmd_context * ctx,
239256
llama_context * lctx,
240-
std::vector<mtmd_input_chunk> & chunks,
257+
mtmd_input_chunks * chunks,
241258
llama_pos pos0,
242259
llama_seq_id seq_id,
243260
int32_t n_batch) {
244261
int32_t ret;
245262
llama_pos n_past = pos0;
246263
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
247264

248-
for (auto & chunk : chunks) {
249-
bool is_last = &chunk == &chunks.back();
250-
if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_TEXT) {
265+
for (auto & chunk : *chunks) {
266+
bool is_last = &chunk == &chunks->back();
267+
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
251268
// TODO @ngxson : may need to split into smaller batches
252269
text_batch.n_tokens = chunk.tokens_text.size();
253270
for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
@@ -268,8 +285,9 @@ int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
268285
return ret;
269286
}
270287

271-
} else if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_IMAGE) {
288+
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
272289
GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
290+
GGML_ASSERT(chunk.tokens_image != nullptr);
273291
int64_t t0 = ggml_time_ms();
274292
if (ctx->print_timings) {
275293
LOG_INF("encoding image...\n");
@@ -284,7 +302,7 @@ int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
284302
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
285303
}
286304

287-
int32_t n_tokens = chunk.tokens_image.n_tokens;
305+
int32_t n_tokens = chunk.tokens_image->n_tokens();
288306
float * embd = mtmd_get_output_embd(ctx);
289307
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
290308
int64_t t1 = ggml_time_ms();

0 commit comments

Comments
 (0)