#include "mtmd.h"
#include "llama.h"

#include <algorithm>
#include <cinttypes>
#include <cstdio> // fprintf
#include <vector>

#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)

// count the total number of tokens contained in a list of chunks
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
    size_t n_tokens = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_tokens += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_tokens;
}

// count the total number of positions consumed by a list of chunks;
// for M-RoPE models this can differ from the token count
llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
    llama_pos n_pos = 0;
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
        auto chunk = mtmd_input_chunks_get(chunks, i);
        auto chunk_type = mtmd_input_chunk_get_type(chunk);
        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            size_t n_tokens_text;
            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
            n_pos += n_tokens_text;
        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_pos;
}

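// A minimal usage sketch (assuming `chunks` was populated by mtmd_tokenize).
// The two helpers above answer different questions: n_tokens is how many
// embeddings will be decoded, while n_pos is how far the KV cache position
// advances; the two can differ for M-RoPE models.
//
//   size_t    n_tokens = mtmd_helper_get_n_tokens(chunks);
//   llama_pos n_pos    = mtmd_helper_get_n_pos(chunks);
//   if ((uint32_t) n_pos > llama_n_ctx(lctx)) {
//       // input does not fit in the context window
//   }
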
// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
struct decode_embd_batch {
    int n_pos_per_embd;
    int n_mmproj_embd;
    std::vector<llama_pos> pos;
    std::vector<llama_pos> pos_view; // used by mrope
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id> seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t> logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
        pos     .resize(n_tokens * n_pos_per_embd);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_ids [n_tokens] = nullptr; // keep the seq_id array nullptr-terminated
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ embd,
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
    }

    void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
        seq_id_0[0] = seq_id;
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
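    // e.g. (hypothetical values) set_position_normal(10, 0) assigns positions
    // 10 .. 10 + n_tokens - 1, all on sequence 0, with logits requested for none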

    // M-RoPE uses n_pos_per_embd = 4 position values per token, stored in 4
    // contiguous sections of batch.n_tokens each: temporal, height, width, unused
    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
        GGML_ASSERT(n_pos_per_embd == 4);
        seq_id_0[0] = seq_id;
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                int i = y * nx + x;
                pos[i                     ] = pos_0;     // temporal: constant for the whole image
                pos[i + batch.n_tokens    ] = pos_0 + y; // height
                pos[i + batch.n_tokens * 2] = pos_0 + x; // width
                pos[i + batch.n_tokens * 3] = 0;         // last pos dim is unused
            }
        }
        for (int i = 0; i < batch.n_tokens; i++) {
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
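    // For example (hypothetical values), a 2x2 image with pos_0 = 10 yields
    // the following `pos` layout (batch.n_tokens = 4):
    //   temporal: 10 10 10 10
    //   height:   10 10 11 11
    //   width:    10 11 10 11
    //   unused:    0  0  0  0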

    llama_batch get_view(int offset, int n_tokens) {
        llama_pos * pos_ptr;
        pos_view.clear();
        pos_view.reserve(n_tokens * n_pos_per_embd);
        if (n_pos_per_embd > 1) {
            // mrope
            // for example, with a src layout of: 1234...1234...1234...1234...
            // an offset of 2 gives the dst:      34...34...34...34...
            for (int i = 0; i < n_pos_per_embd; i++) {
                // n_tokens must be less than or equal to batch.n_tokens:
                // batch.n_tokens is the **total** number of tokens,
                // n_tokens is the number of tokens in this view
                size_t src_idx = i * batch.n_tokens + offset;
                pos_view.insert(pos_view.end(),
                    pos.data() + src_idx,
                    pos.data() + src_idx + n_tokens);
            }
            pos_ptr = pos_view.data();
        } else {
            // normal
            pos_ptr = pos.data() + offset;
        }
        return {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ batch.embd + offset * n_mmproj_embd,
            /*pos      =*/ pos_ptr,
            /*n_seq_id =*/ batch.n_seq_id + offset,
            /*seq_id   =*/ batch.seq_id + offset,
            /*logits   =*/ batch.logits + offset,
        };
    }
};

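// A minimal sketch of driving decode_embd_batch by hand (assuming `embd` holds
// n_tokens encoded image embeddings of n_mmproj_embd floats each; the helper
// functions below wrap exactly this pattern):
//
//   decode_embd_batch batch_embd(embd, n_tokens, /*n_pos_per_embd=*/1, n_mmproj_embd);
//   batch_embd.set_position_normal(n_past, seq_id);
//   for (int32_t off = 0; off < n_tokens; off += n_batch) {
//       llama_batch view = batch_embd.get_view(off, std::min(n_batch, n_tokens - off));
//       if (llama_decode(lctx, view) != 0) { /* handle the error */ }
//   }
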
// Helper function for decoding an image whose embeddings have already been calculated
int32_t mtmd_helper_decode_image_chunk(
        mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        float * encoded_embd,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        llama_pos * new_n_past) {
    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
        return -1;
    }
    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
    if (!image_tokens) {
        LOG_ERR("failed to decode image chunk: image tokens are null\n");
        return -1;
    }

    const llama_model * model = llama_get_model(lctx);
    int n_mmproj_embd = llama_model_n_embd(model);
    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
    int32_t i_batch = 0;
    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; // ceil(n_tokens / n_batch)
    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

    const int nx = mtmd_image_tokens_get_nx(image_tokens);
    const int ny = mtmd_image_tokens_get_ny(image_tokens);

    if (mtmd_decode_use_mrope(ctx)) {
        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
    } else {
        batch_embd.set_position_normal(n_past, seq_id);
    }

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, false);
        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
    }

    while (i_batch < n_img_batches) { // split into batches
        int pos_offset = i_batch*n_batch;
        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

        int64_t t1 = ggml_time_ms();
        int32_t ret = llama_decode(lctx, batch_embd_view);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_set_causal_attn(lctx, true); // restore causal attn
            return ret;
        }

        LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);

        i_batch++;
    }

    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
    *new_n_past = n_past;

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, true);
    }
    return 0;
}

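// A sketch of calling mtmd_helper_decode_image_chunk directly (assuming the
// chunk was already encoded with mtmd_encode, as mtmd_helper_eval_chunk_single
// does below, and that n_past, seq_id and n_batch are tracked by the caller):
//
//   float * embd = mtmd_get_output_embd(ctx);
//   llama_pos new_n_past;
//   int32_t ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd,
//                                                n_past, seq_id, n_batch, &new_n_past);
//   if (ret == 0) { n_past = new_n_past; }
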
int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
                                      struct llama_context * lctx,
                                      const mtmd_input_chunk * chunk,
                                      llama_pos n_past,
                                      llama_seq_id seq_id,
                                      int32_t n_batch,
                                      bool logits_last,
                                      llama_pos * new_n_past) {
    int32_t ret;
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
    auto chunk_type = mtmd_input_chunk_get_type(chunk);

    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
        size_t n_tokens;
        const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
        // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
        size_t i = 0;
        while (i < n_tokens) { // split into batches
            text_batch.n_tokens = 0; // clear the batch
            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
                // j indexes into the batch, i indexes into the chunk's tokens;
                // they diverge once the chunk spans more than one batch
                const int32_t j = text_batch.n_tokens++;
                text_batch.token   [j]    = tokens[i];
                text_batch.pos     [j]    = n_past++;
                text_batch.n_seq_id[j]    = 1;
                text_batch.seq_id  [j][0] = seq_id;
                text_batch.logits  [j]    = false;
            }
            bool is_last_token = (i == n_tokens);
            if (logits_last && is_last_token) {
                text_batch.logits[text_batch.n_tokens - 1] = true;
            }
            ret = llama_decode(lctx, text_batch);
            if (ret != 0) {
                LOG_ERR("failed to decode text\n");
                llama_batch_free(text_batch);
                return ret;
            }
            *new_n_past = n_past; // n_past was already advanced per token above
        }

    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
        int64_t t0 = ggml_time_ms();

        LOG_INF("encoding image or slice...\n");

        ret = mtmd_encode(ctx, image_tokens);
        if (ret != 0) {
            LOG_ERR("failed to encode image\n");
            llama_batch_free(text_batch);
            return ret;
        }

        LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);

        float * embd = mtmd_get_output_embd(ctx);
        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_batch_free(text_batch);
            return ret;
        }
    } else {
        GGML_ABORT("chunk type not supported");
    }

    llama_batch_free(text_batch); // also free the batch on the success path, not just on errors
    return 0;
}

int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
                                struct llama_context * lctx,
                                const mtmd_input_chunks * chunks,
                                llama_pos n_past,
                                llama_seq_id seq_id,
                                int32_t n_batch,
                                bool logits_last,
                                llama_pos * new_n_past) {
    size_t n_chunks = mtmd_input_chunks_size(chunks);
    if (n_chunks == 0) {
        LOG_ERR("no chunks to eval\n");
        return 0;
    }

    for (size_t i = 0; i < n_chunks; i++) {
        // only request logits for the last token of the final chunk
        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
        auto chunk = mtmd_input_chunks_get(chunks, i);

        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
        if (res != 0) {
            LOG_ERR("failed to eval chunk %zu\n", i);
            return res;
        }
        *new_n_past = n_past;
    }

    return 0;
}
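
// A minimal end-to-end sketch (assuming `chunks` was produced by mtmd_tokenize
// from a prompt containing image markers; see mtmd.h for that contract):
//
//   llama_pos new_n_past;
//   if (mtmd_helper_eval_chunks(ctx, lctx, chunks, /*n_past=*/0, /*seq_id=*/0,
//                               n_batch, /*logits_last=*/true, &new_n_past) == 0) {
//       // logits for the final token are now available, e.g. via
//       // llama_get_logits_ith(lctx, -1), and generation can proceed
//   }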