#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"

#include "ggml.h"
#include "ggml-cpp.h"
#include "gguf.h"

#include "whisper-preprocessor.h"

// standard headers used below
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

static void show_additional_info(int /*argc*/, char ** argv) {
    LOG(
        "TODO\n\n"
        "Usage: %s [options] -m <model> --mmproj <mmproj> --in-file <audio> -p <prompt>\n\n",
        argv[0]
    );
}
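
// storage filled by the eval callback below with the encoder's pooled output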
struct hook_data {
    std::vector<float> embd;  // captured encoder output
    int n_token_output = 0;   // number of output positions (as reported by the callback)
};

// hook to retrieve the embeddings (because we cannot use an arbitrary output tensor **shape**)
static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
    hook_data * data = (hook_data *) user_data;

    if (t && strcmp(t->name, "result_embd_pooled") == 0) {
        if (ask) return true; // report interest in observing this tensor
        data->embd.resize(ggml_nelements(t));
        data->n_token_output = t->ne[0];
        ggml_backend_tensor_get(t, data->embd.data(), 0, ggml_nbytes(t));
        printf("%s tensor size: %lld, %lld\n", t->name, (long long) t->ne[0], (long long) t->ne[1]);
        return true;
    }

    return false;
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;
    params.prompt = "Transcribe the audio";
    params.sampling.temp = 0.2f; // lower temp by default for better quality

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_ASR, show_additional_info)) {
        return 1;
    }

    common_init();

    if (params.mmproj.path.empty() || params.in_files.empty()) {
        show_additional_info(argc, argv);
        return 1;
    }
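
    // two llama.cpp instances are used: the main LLM (from -m) that generates
    // text, and the audio encoder (from --mmproj) whose pooled output is
    // grabbed via the eval callback above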
    common_init_result ll_result = common_init_from_params(params);
    llama_model   * ll_model = ll_result.model.get();
    llama_context * ll_ctx   = ll_result.context.get();

    if (!ll_model || !ll_ctx) {
        LOG_ERR("failed to initialize the LLM\n");
        return 1;
    }

    common_params params_enc(params); // copy
    params_enc.model.path = params.mmproj.path;
    params_enc.warmup     = false;
    // 1500 matches the Whisper encoder's fixed 30-second window
    // (3000 mel frames, downsampled 2x by the conv frontend)
    params_enc.n_ubatch   = 1500;
    params_enc.n_batch    = 1500;
    params_enc.embedding  = true;

    hook_data hook_data;
    params_enc.cb_eval           = ggml_callback;
    params_enc.cb_eval_user_data = &hook_data;

    common_init_result enc_result = common_init_from_params(params_enc);
    llama_model   * enc_model = enc_result.model.get();
    llama_context * enc_ctx   = enc_result.context.get();

    if (!enc_model || !enc_ctx) {
        LOG_ERR("failed to initialize the audio encoder model\n");
        return 1;
    }
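
    // load the mel filterbank: it is stored as a tensor inside the mmproj
    // GGUF (preprocessing metadata rather than a model weight), so it is
    // read manually with no_alloc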
    whisper_preprocessor::whisper_filters mel_filters;
    {
        ggml_context * meta = nullptr;
        gguf_init_params gguf_params = {
            /*.no_alloc = */ true,
            /*.ctx      = */ &meta,
        };
        gguf_context_ptr ctx_gguf(gguf_init_from_file(params_enc.model.path.c_str(), gguf_params));
        if (!ctx_gguf) {
            LOG_ERR("failed to load GGUF file '%s'\n", params_enc.model.path.c_str());
            return 1;
        }

        // read size
        auto mel_filters_tensor = ggml_get_tensor(meta, "whisper.mel_filters");
        if (!mel_filters_tensor) {
            LOG_ERR("tensor 'whisper.mel_filters' not found in '%s'\n", params_enc.model.path.c_str());
            return 1;
        }
        mel_filters.n_mel = mel_filters_tensor->ne[1];
        mel_filters.n_fft = mel_filters_tensor->ne[0];
        mel_filters.data.resize(mel_filters.n_mel * mel_filters.n_fft);

        // read data
        auto idx    = gguf_find_tensor(ctx_gguf.get(), "whisper.mel_filters");
        auto offset = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), idx);
        auto size   = gguf_get_tensor_size(ctx_gguf.get(), idx);
        auto fin    = std::ifstream(params_enc.model.path, std::ios::binary);
        fin.seekg(offset, std::ios::beg);
        fin.read(reinterpret_cast<char *>(mel_filters.data.data()), size);
        fin.close();

        printf("mel_filters: n_mel = %d, n_fft = %d\n", mel_filters.n_mel, mel_filters.n_fft);
        ggml_free(meta);
    }

    // read wav file
    std::vector<float> pcmf32;               // mono-channel F32 PCM
    std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
    auto fname_inp = params.in_files[0];     // TODO: support multiple files
    if (!wav_utils::read_wav(fname_inp, pcmf32, pcmf32s, false)) {
        LOG_ERR("failed to read WAV file '%s'\n", fname_inp.c_str());
        return 1;
    }

    // mel spectrogram
    whisper_preprocessor::whisper_mel mel;
    whisper_preprocessor::log_mel_spectrogram(
        pcmf32.data(),
        pcmf32.size(),
        WHISPER_SAMPLE_RATE,
        WHISPER_N_FFT,
        WHISPER_HOP_LENGTH,
        mel_filters.n_mel,
        4, // threads
        mel_filters,
        false,
        mel);
    printf("mel.n_len: %d\n", mel.n_len);
    printf("mel.n_mel: %d\n", mel.n_mel);
    printf("mel.size:  %zu\n", mel.data.size());
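
    // encode audio: the conv frontend consumes two mel frames per encoder
    // position, so up to 2*n_ctx frames fill one encoder pass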
    {
        int n_ctx  = llama_model_n_ctx_train(enc_model);
        int n_embd = llama_model_n_embd(enc_model);
        std::vector<float> embd(n_ctx * n_embd, 0.0f);

        // copy the mel spectrogram into the input buffer
        {
            int mel_offset = 0;

            const int i0 = std::min(mel_offset, mel.n_len);
            const int i1 = std::min(mel_offset + 2*n_ctx, mel.n_len);

            for (int j = 0; j < mel.n_mel; ++j) {
                for (int i = i0; i < i1; ++i) {
                    embd[j*2*n_ctx + (i - i0)] = mel.data[j*mel.n_len + i];
                }
            }
        }

        // build a batch that carries raw embeddings instead of tokens
        llama_batch batch_embd = llama_batch_init(n_ctx, n_embd, 1);
        batch_embd.n_tokens = n_ctx;
        for (int i = 0; i < batch_embd.n_tokens; i++) {
            batch_embd.pos[i]       = 0; // dummy, unused
            batch_embd.seq_id[i][0] = 0;
            batch_embd.n_seq_id[i]  = 1;
            batch_embd.logits[i]    = false;
        }
        std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));

        if (llama_decode(enc_ctx, batch_embd) != 0) {
            LOG_ERR("%s: audio llama_decode() failed\n", __func__);
            return 1;
        }

        // debug: print out the first 10 embeddings
        // float * embd_out = hook_data.embd.data();
        // for (int i = 0; i < 10; i++) {
        //     printf("embd_out[%d] = %f\n", i, embd_out[i]);
        // }

        llama_batch_free(batch_embd);
    }
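
    // generate text: feed the prompt tokens and the captured audio embeddings
    // to the LLM, then sample autoregressively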
    {
        llama_batch batch_token = llama_batch_init(llama_n_ctx(ll_ctx), 0, 1);
        llama_batch batch_embd  = llama_batch_init(hook_data.n_token_output, llama_model_n_embd(ll_model), 1);
        int n_past = 0;

        // tokenize and decode a chunk of text, tracking n_past
        auto eval_text = [&](const std::string & text, bool add_bos = false) {
            llama_tokens prompt_tokens = common_tokenize(ll_ctx, text, add_bos, true);
            common_batch_clear(batch_token);
            for (auto & token : prompt_tokens) {
                common_batch_add(batch_token, token, n_past++, {0}, false);
            }
            if (!add_bos) {
                // TODO: a bit hacky here - request logits only for the final
                // chunk of the prompt (the one not starting with BOS)
                batch_token.logits[batch_token.n_tokens - 1] = true;
            }
            if (llama_decode(ll_ctx, batch_token) != 0) {
                LOG_ERR("%s: text llama_decode() failed\n", __func__);
                exit(1);
            }
        };

        // decode a batch of raw embeddings (the audio tokens), tracking n_past
        auto eval_embd = [&](std::vector<float> & embd, int n_tokens) {
            batch_embd.n_tokens = n_tokens;
            for (int i = 0; i < n_tokens; i++) {
                batch_embd.pos[i]       = n_past++;
                batch_embd.seq_id[i][0] = 0;
                batch_embd.n_seq_id[i]  = 1;
                batch_embd.logits[i]    = false;
            }
            std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
            if (llama_decode(ll_ctx, batch_embd) != 0) {
                LOG_ERR("%s: embedding llama_decode() failed\n", __func__);
                exit(1);
            }
        };
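
        // the prompt is assembled as: user header + text prompt + audio
        // embeddings + assistant header (Llama 3 style template, hard-coded)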
        eval_text("<|start_header_id|>user<|end_header_id|>\n\n" + params.prompt + "\n\n", true);
        eval_embd(hook_data.embd, hook_data.n_token_output);
        eval_text("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");

        struct common_sampler * smpl = common_sampler_init(ll_model, params.sampling);

        int n_predict = 50; // hard-coded generation limit
        for (int i = 0; i < n_predict; i++) {
            llama_token token_id = common_sampler_sample(smpl, ll_ctx, -1);
            common_sampler_accept(smpl, token_id, true);

            if (llama_vocab_is_eog(llama_model_get_vocab(ll_model), token_id)) {
                printf("\n");
                break; // end of generation
            }

            printf("%s", common_token_to_piece(ll_ctx, token_id).c_str());
            fflush(stdout);

            // feed the sampled token back for the next step
            common_batch_clear(batch_token);
            common_batch_add(batch_token, token_id, n_past++, {0}, true);
            if (llama_decode(ll_ctx, batch_token) != 0) {
                LOG_ERR("failed to decode token\n");
                return 1;
            }
        }

        common_sampler_free(smpl);
        llama_batch_free(batch_token);
        llama_batch_free(batch_embd);
    }

    return 0;
}