#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "ggml.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <iostream>
#include <fstream>

struct phi4mm_context {
    struct clip_ctx * ctx_clip = NULL;
    common_init_result llama_init;

    llama_model * model;
    llama_context * lctx;
    llama_adapter_lora * vision_lora;

    phi4mm_context(common_params & params) : llama_init(common_init_from_params(params)) {
        model = llama_init.model.get();
        lctx = llama_init.context.get();
        vision_lora = llama_init.lora[0].get();
        // the vision LoRA is applied only while decoding image embeddings, so start with it disabled
        llama_clear_adapter_lora(lctx);
        init_clip_model(params);
    }

    void init_clip_model(common_params & params) {
        const char * clip_path = params.mmproj.c_str();
        ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
        if (ctx_clip == nullptr) {
            GGML_ABORT("Failed to load CLIP model from %s\n", clip_path);
        }
    }

    ~phi4mm_context() {
        clip_free(ctx_clip);
    }
};

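// helper for building a llama_batch that carries raw embeddings instead of token ids;
// llama_batch_init() only allocates token storage, so the arrays are managed by hand here
// (seq_ids keeps one extra nullptr entry as an end-of-array sentinel)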
struct decode_embd_batch {
    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id> seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t> logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos     .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens   =*/ nullptr,
            /*embd     =*/ embd,
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
        for (int i = 0; i < n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
};

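// simple owning RGB bitmap (currently unused in this example)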
struct inp_bitmap {
    int nx;
    int ny;
    std::vector<unsigned char> data;
};

static void show_additional_info(int /*argc*/, char ** argv) {
    LOG("Usage: %s -m <model> --mmproj <clip-model> --lora <vision-lora> --image <image> [-p <prompt>]\n", argv[0]);
    LOG("  the prompt must contain exactly one '$', marking where the image embeddings are inserted\n");
}

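// tokenize `input` and decode it into the context, advancing n_past;
// logits are requested only for the last token, and only when `logits_last` is set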
static void eval_text(phi4mm_context & ctx, int & n_past, const std::string & input, bool logits_last = false) {
    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    for (llama_token & t : tokens) {
        common_batch_add(batch, t, n_past++, {0}, false);
    }
    if (logits_last) {
        batch.logits[batch.n_tokens - 1] = true;
    }
    LOG("eval_text (n_tokens = %d): %s\n", (int) tokens.size(), input.c_str());
    if (llama_decode(ctx.lctx, batch)) {
        GGML_ABORT("Failed to decode\n");
    }
    llama_batch_free(batch);
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    // default values
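    // the '$' in the prompt marks where the image embeddings are inserted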
    params.prompt = "<|user|>$what did you see?<|end|><|assistant|>";
    params.n_predict = 64;
    params.sampling.temp = 0.0f;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
        return 1;
    }

    common_init();

    if (params.mmproj.empty() || params.image.empty()) {
        show_additional_info(argc, argv);
        return 1;
    }

    if (params.lora_adapters.empty()) {
        LOG_ERR("error: no vision lora adapters specified\n");
        return 1;
    }

    phi4mm_context ctx(params);
    printf("%s: %s\n", __func__, params.model.c_str());

    int n_threads = params.cpuparams.n_threads;
    int n_past = 0;

    std::vector<std::string> prompt_parts = string_split<std::string>(params.prompt, '$');
    GGML_ASSERT(prompt_parts.size() == 2);
    eval_text(ctx, n_past, prompt_parts[0], false);

    // process images
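    // each image goes through: load -> CLIP preprocess -> encode to embeddings ->
    // decode into the context with the vision LoRA enabled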
    for (auto & image : params.image) {
        std::vector<float> image_embd_v;
        int n_embd = llama_model_n_embd(ctx.model);
        int n_tokens = 256; // hardcoded number of embedding tokens the projector produces per image
        image_embd_v.resize(n_tokens * n_embd);

        bool ok;
        struct clip_image_u8 * img_u8 = clip_image_u8_init();
        ok = clip_image_load_from_file(image.c_str(), img_u8);
        if (!ok) {
            LOG_ERR("Unable to load image %s\n", image.c_str());
            clip_image_u8_free(img_u8);
            return 1;
        }

        clip_image_f32_batch batch_f32;
        ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
        if (!ok) {
            LOG_ERR("Unable to preprocess image\n");
            clip_image_u8_free(img_u8);
            return 1;
        }

        LOG("Encoding image %s\n", image.c_str());
        ok = clip_image_batch_encode(ctx.ctx_clip, n_threads, &batch_f32, image_embd_v.data());
        if (!ok) {
            LOG_ERR("Unable to encode image\n");
            clip_image_f32_batch_free(&batch_f32);
            clip_image_u8_free(img_u8);
            return 1;
        }
        clip_image_f32_batch_free(&batch_f32);
        clip_image_u8_free(img_u8);

        // decode the image embeddings, enabling the vision LoRA only for these tokens
        llama_set_adapter_lora(ctx.lctx, ctx.vision_lora, 1.0f);
        decode_embd_batch batch_img(image_embd_v.data(), n_tokens, n_past, 0);
        if (llama_decode(ctx.lctx, batch_img.batch)) {
            LOG_ERR("failed to decode image\n");
            return 1;
        }
        llama_clear_adapter_lora(ctx.lctx);
        n_past += n_tokens;
    }

    eval_text(ctx, n_past, prompt_parts[1], true);

    // generate text
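    // autoregressive loop: sample from the logits of the last decoded token,
    // print the piece, then decode the sampled token to get the next logits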
    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
    const llama_vocab * vocab = llama_model_get_vocab(ctx.model);
    int n_prompt = n_past;
    llama_batch batch = llama_batch_init(1, 0, 1);
    while (true) {
        int n_generated = n_past - n_prompt;
        if (n_generated >= params.n_predict) {
            // stop once n_predict tokens have been generated
            printf("\n");
            break;
        }

        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
        common_sampler_accept(smpl, token_id, true);
        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
        fflush(stdout);

        if (llama_vocab_is_eog(vocab, token_id)) {
            printf("\n");
            break;
        }

        // eval the token
        common_batch_clear(batch);
        common_batch_add(batch, token_id, n_past++, {0}, true);
        if (llama_decode(ctx.lctx, batch)) {
            LOG_ERR("failed to decode token\n");
            break;
        }
    }

    llama_batch_free(batch);
    common_sampler_free(smpl);

    return 0;
}