
Commit a634d75

mtmd : move helpers to dedicated file (ggml-org#13442)
* mtmd : move helpers to dedicated file
* fix windows build
* rm redundant include
1 parent 62d4250 commit a634d75

File tree

4 files changed: +326 −312 lines changed

tools/mtmd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ endif()
 
 add_library(mtmd OBJECT
     mtmd.cpp
+    mtmd-helper.cpp
     mtmd.h
     clip.cpp
     clip.h

tools/mtmd/mtmd-helper.cpp

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
#include "mtmd.h"
#include "llama.h"

#include <algorithm>
#include <cinttypes>
#include <cstdio> // fprintf, used by the LOG_* macros below
#include <vector>

#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)

size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
    size_t n_tokens = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk      = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_tokens += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_tokens;
}

llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
    llama_pos n_pos = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk      = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_pos += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_pos;
}
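
// note: the two helpers above can disagree for image chunks: an image advances
// the position counter by mtmd_image_tokens_get_n_pos(), which is not
// necessarily equal to its token count (e.g. with M-RoPE position encoding)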

// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
struct decode_embd_batch {
    int n_pos_per_embd;
    int n_mmproj_embd;
    std::vector<llama_pos>      pos;
    std::vector<llama_pos>      pos_view; // used by mrope
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id>   seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t>         logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd)
            : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
        pos     .resize(n_tokens * n_pos_per_embd);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ embd,
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
    }

    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
        seq_id_0[0] = seq_id;
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }

    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
        GGML_ASSERT(n_pos_per_embd == 4);
        seq_id_0[0] = seq_id;
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                int i = y * nx + x;
                pos[i                     ] = pos_0;
                pos[i + batch.n_tokens    ] = pos_0 + y;
                pos[i + batch.n_tokens * 2] = pos_0 + x;
                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
            }
        }
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }

    llama_batch get_view(int offset, int n_tokens) {
        llama_pos * pos_ptr;
        pos_view.clear();
        pos_view.reserve(n_tokens * n_pos_per_embd);
        if (n_pos_per_embd > 1) {
            // mrope
            // for example, with a src layout of: 1234...1234...1234...1234...
            // offset 2 gives us dst: 34...34...34...34...
            for (int i = 0; i < n_pos_per_embd; i++) {
                // n_tokens is the number of viewed tokens and is assumed to be
                // less than or equal to batch.n_tokens (the total token count)
                size_t src_idx = i * batch.n_tokens + offset;
                pos_view.insert(pos_view.end(),
                    pos.data() + src_idx,
                    pos.data() + src_idx + n_tokens);
            }
            pos_ptr = pos_view.data();
        } else {
            // normal
            pos_ptr = pos.data() + offset;
        }
        return {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ batch.embd + offset * n_mmproj_embd,
            /*pos      =*/ pos_ptr,
            /*n_seq_id =*/ batch.n_seq_id + offset,
            /*seq_id   =*/ batch.seq_id + offset,
            /*logits   =*/ batch.logits + offset,
        };
    }
};
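
// note: for M-RoPE the pos buffer above is planar: four blocks of
// batch.n_tokens positions each (base, base + y, base + x, unused),
// which is why get_view() gathers one n_tokens-sized slice per dimension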

// Helper function for decoding an image whose embeddings have already been calculated
int32_t mtmd_helper_decode_image_chunk(
        mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        float * encoded_embd,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        llama_pos * new_n_past) {
    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
        return -1;
    }
    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
    if (!image_tokens) {
        LOG_ERR("failed to decode image chunk: image tokens are null\n");
        return -1;
    }

    const llama_model * model = llama_get_model(lctx);
    int n_mmproj_embd  = llama_model_n_embd(model);
    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
    int32_t i_batch  = 0;
    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; // ceil(n_tokens / n_batch)
    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

    const int nx = mtmd_image_tokens_get_nx(image_tokens);
    const int ny = mtmd_image_tokens_get_ny(image_tokens);

    if (mtmd_decode_use_mrope(ctx)) {
        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
    } else {
        batch_embd.set_position_normal(n_past, seq_id);
    }

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, false);
        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
    }

    while (i_batch < n_img_batches) { // split into batches
        int pos_offset = i_batch*n_batch;
        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

        int64_t t1 = ggml_time_ms();
        int32_t ret = llama_decode(lctx, batch_embd_view);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_set_causal_attn(lctx, true); // restore causal attn
            return ret;
        }

        LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);

        i_batch++;
    }

    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
    *new_n_past = n_past;

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, true);
    }
    return 0;
}

int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        bool logits_last,
        llama_pos * new_n_past) {
    int32_t ret;
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
    auto chunk_type = mtmd_input_chunk_get_type(chunk);

    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
        size_t n_tokens;
        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
        // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
        size_t i = 0;
        while (i < n_tokens) { // split into batches
            text_batch.n_tokens = 0; // clear the batch
            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
                // index with the in-batch position rather than i, so that writes
                // stay within the n_batch-sized buffers from llama_batch_init()
                const int32_t j = text_batch.n_tokens++;
                text_batch.token   [j]    = tokens[i];
                text_batch.pos     [j]    = n_past++;
                text_batch.n_seq_id[j]    = 1;
                text_batch.seq_id  [j][0] = seq_id;
                text_batch.logits  [j]    = false;
            }
            bool is_last_token = (i == n_tokens);
            if (logits_last && is_last_token) {
                text_batch.logits[text_batch.n_tokens - 1] = true;
            }
            ret = llama_decode(lctx, text_batch);
            if (ret != 0) {
                LOG_ERR("failed to decode text\n");
                llama_batch_free(text_batch);
                return ret;
            }
            *new_n_past += text_batch.n_tokens;
        }

    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
        int64_t t0 = ggml_time_ms();

        LOG_INF("encoding image or slice...\n");

        ret = mtmd_encode(ctx, image_tokens);
        if (ret != 0) {
            LOG_ERR("failed to encode image\n");
            llama_batch_free(text_batch);
            return ret;
        }

        LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);

        float * embd = mtmd_get_output_embd(ctx);
        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_batch_free(text_batch);
            return ret;
        }
    } else {
        GGML_ABORT("chunk type not supported");
    }

    llama_batch_free(text_batch); // free on the success path as well
    return 0;
}

int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunks * chunks,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        bool logits_last,
        llama_pos * new_n_past) {
    size_t n_chunks = mtmd_input_chunks_size(chunks);
    if (n_chunks == 0) {
        LOG_ERR("no chunks to eval\n");
        return 0;
    }

    for (size_t i = 0; i < n_chunks; i++) {
        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
        auto chunk = mtmd_input_chunks_get(chunks, i);

        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
        if (res != 0) {
            LOG_ERR("failed to eval chunk %zu\n", i);
            return res;
        }
        *new_n_past = n_past;
    }

    return 0;
}
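
For orientation, a minimal sketch of how a caller might drive the helpers moved into this file. The setup is assumed to happen elsewhere (the mtmd_context, the llama_context, and the chunks produced by the tokenizer declared in mtmd.h); only functions defined in this file plus llama_n_ctx() are called, and the function name and n_batch value of 512 are illustrative, not part of the commit:

#include "mtmd.h"
#include "llama.h"

// minimal sketch: mctx, lctx and chunks are assumed to be created elsewhere
// (chunks would typically come from the tokenizer declared in mtmd.h)
int32_t eval_multimodal_prompt(mtmd_context * mctx,
                               llama_context * lctx,
                               const mtmd_input_chunks * chunks) {
    // optional sanity check: make sure the whole prompt fits the context window
    size_t n_tokens = mtmd_helper_get_n_tokens(chunks);
    if (n_tokens > (size_t) llama_n_ctx(lctx)) {
        return -1; // prompt does not fit
    }

    // decode every chunk (text and image) on sequence 0, requesting logits
    // only for the last token of the last chunk; n_past receives the position
    // to continue generation from
    llama_pos n_past = 0;
    return mtmd_helper_eval_chunks(mctx, lctx, chunks,
                                   /*n_past      =*/ 0,
                                   /*seq_id      =*/ 0,
                                   /*n_batch     =*/ 512, // arbitrary example value
                                   /*logits_last =*/ true,
                                   &n_past);
}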
