Commit 0ce5415

mtmd : move helpers to dedicated file

1 parent: 43dfd74

3 files changed: +331 -312 lines changed

tools/mtmd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -28,6 +28,7 @@ endif()
 
 add_library(mtmd OBJECT
     mtmd.cpp
+    mtmd-helper.cpp
     mtmd.h
     clip.cpp
     clip.h
```

tools/mtmd/mtmd-helper.cpp

Lines changed: 316 additions & 0 deletions (new file)

```cpp
#include "clip.h"
#include "mtmd.h"

#include "llama.h"

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>

#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)

size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
    size_t n_tokens = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_tokens += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_tokens;
}

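// note: n_pos is tracked separately from n_tokens because the two can differ,
// e.g. for M-RoPE models an image chunk may advance the position counter by a
// different amount than the number of embeddings it contributes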
llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
    llama_pos n_pos = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_pos += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_pos;
}

// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
struct decode_embd_batch {
    int n_pos_per_embd;
    int n_mmproj_embd;
    std::vector<llama_pos>      pos;
    std::vector<llama_pos>      pos_view; // used by mrope
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id>   seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t>         logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
        pos     .resize(n_tokens * n_pos_per_embd);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ embd,
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
    }

    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
        seq_id_0[0] = seq_id;
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }

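    // M-RoPE (as used by Qwen2-VL-style models) keeps four position streams per
    // embedding: a shared temporal index, the patch's y and x coordinates, and
    // one unused stream; each stream occupies a batch.n_tokens-long section of pos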
    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
        GGML_ASSERT(n_pos_per_embd == 4);
        seq_id_0[0] = seq_id;
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                int i = y * nx + x;
                pos[i                     ] = pos_0;
                pos[i + batch.n_tokens    ] = pos_0 + y;
                pos[i + batch.n_tokens * 2] = pos_0 + x;
                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
            }
        }
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }

    llama_batch get_view(int offset, int n_tokens) {
        llama_pos * pos_ptr;
        pos_view.clear();
        pos_view.reserve(n_tokens * n_pos_per_embd);
        if (n_pos_per_embd > 1) {
            // mrope
            // for example, with layout of src: 1234...1234...1234...1234...
            // offset 2 will give us dst: 34...34...34...34...
            for (int i = 0; i < n_pos_per_embd; i++) {
                // assume n_tokens is less than or equal to batch.n_tokens
                // batch.n_tokens is number of **total** tokens
                // n_tokens is number of viewed token
                size_t src_idx = i * batch.n_tokens + offset;
                pos_view.insert(pos_view.end(),
                    pos.data() + src_idx,
                    pos.data() + src_idx + n_tokens);
            }
            pos_ptr = pos_view.data();
        } else {
            // normal
            pos_ptr = pos.data() + offset;
        }
        return {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ batch.embd     + offset * n_mmproj_embd,
            /*pos      =*/ pos_ptr,
            /*n_seq_id =*/ batch.n_seq_id + offset,
            /*seq_id   =*/ batch.seq_id   + offset,
            /*logits   =*/ batch.logits   + offset,
        };
    }
};

// Helper function for decoding an image whose embeddings have already been calculated
int32_t mtmd_helper_decode_image_chunk(
        mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        float * encoded_embd,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        llama_pos * new_n_past) {
    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
        return -1;
    }
    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
    if (!image_tokens) {
        LOG_ERR("failed to decode image chunk: image tokens are null\n");
        return -1;
    }

    const llama_model * model = llama_get_model(lctx);
    int n_mmproj_embd = llama_model_n_embd(model);
    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
    int32_t i_batch = 0;
    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

    const int nx = mtmd_image_tokens_get_nx(image_tokens);
    const int ny = mtmd_image_tokens_get_ny(image_tokens);

    if (mtmd_decode_use_mrope(ctx)) {
        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
    } else {
        batch_embd.set_position_normal(n_past, seq_id);
    }

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, false);
        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
    }

    while (i_batch < n_img_batches) { // split into batches
        int pos_offset = i_batch*n_batch;
        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

        int64_t t1 = ggml_time_ms();
        int32_t ret = llama_decode(lctx, batch_embd_view);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_set_causal_attn(lctx, true); // restore causal attn
            return ret;
        }

        LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);

        i_batch++;
    }

    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
    *new_n_past = n_past;

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, true);
    }
    return 0;
}

int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        bool logits_last,
        llama_pos * new_n_past) {
    int32_t ret;
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
    auto chunk_type = mtmd_input_chunk_get_type(chunk);

    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
        size_t n_tokens;
        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
        // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
        size_t i = 0;
        while (i < n_tokens) { // split into batches
            text_batch.n_tokens = 0; // clear the batch
            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
                // index the batch arrays with a batch-local position: they are
                // only n_batch entries long, so the global token index i would
                // run past them on the second and later batches
                int32_t j = text_batch.n_tokens++;
                text_batch.token   [j]    = tokens[i];
                text_batch.pos     [j]    = n_past++;
                text_batch.n_seq_id[j]    = 1;
                text_batch.seq_id  [j][0] = seq_id;
                text_batch.logits  [j]    = false;
            }
            bool is_last_token = (i == n_tokens);
            if (logits_last && is_last_token) {
                text_batch.logits[text_batch.n_tokens - 1] = true;
            }
            ret = llama_decode(lctx, text_batch);
            if (ret != 0) {
                LOG_ERR("failed to decode text\n");
                llama_batch_free(text_batch);
                return ret;
            }
            *new_n_past += text_batch.n_tokens;
        }

    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
        int64_t t0 = ggml_time_ms();

        LOG_INF("encoding image or slice...\n");

        ret = mtmd_encode(ctx, image_tokens);
        if (ret != 0) {
            LOG_ERR("failed to encode image\n");
            llama_batch_free(text_batch);
            return ret;
        }

        LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);

        float * embd = mtmd_get_output_embd(ctx);
        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_batch_free(text_batch);
            return ret;
        }
    } else {
        GGML_ABORT("chunk type not supported");
    }

    llama_batch_free(text_batch); // also free the batch on the success path
    return 0;
}

int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunks * chunks,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        bool logits_last,
        llama_pos * new_n_past) {
    size_t n_chunks = mtmd_input_chunks_size(chunks);
    if (n_chunks == 0) {
        LOG_ERR("no chunks to eval\n");
        return 0;
    }

    for (size_t i = 0; i < n_chunks; i++) {
        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
        auto chunk = mtmd_input_chunks_get(chunks, i);

        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
        if (res != 0) {
            LOG_ERR("failed to eval chunk %zu\n", i);
            return res;
        }
        *new_n_past = n_past;
    }

    return 0;
}
```
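For reference, here is a minimal sketch of how a caller might drive these helpers once an `mtmd_context`, a `llama_context`, and a tokenized `mtmd_input_chunks` list already exist (their creation, e.g. via `mtmd_tokenize`, is assumed to have happened elsewhere and is not part of this commit; `eval_prompt` is a hypothetical name):

```cpp
#include "mtmd.h"
#include "llama.h"

// Hypothetical caller: evaluates a tokenized multimodal prompt on sequence 0
// and returns the new n_past on success, or the decode error code on failure.
static int32_t eval_prompt(mtmd_context * ctx,
                           llama_context * lctx,
                           mtmd_input_chunks * chunks,
                           int32_t n_batch) {
    llama_pos n_past = 0;

    // make sure the whole prompt fits into the context window
    if (mtmd_helper_get_n_pos(chunks) > (llama_pos) llama_n_ctx(lctx)) {
        return -1;
    }

    // encode/decode all chunks, requesting logits for the final token so that
    // sampling can begin immediately after the prompt
    int32_t ret = mtmd_helper_eval_chunks(ctx, lctx, chunks,
                                          n_past, /*seq_id =*/ 0, n_batch,
                                          /*logits_last =*/ true, &n_past);
    return ret == 0 ? (int32_t) n_past : ret;
}
```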
