Commit d44c721

output but wrong
1 parent 62695aa commit d44c721

15 files changed: +9792 -4 lines changed


common/arg.cpp

Lines changed: 3 additions & 3 deletions
@@ -1408,7 +1408,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
             params.in_files.push_back(value);
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_ASR}));
     add_opt(common_arg(
         {"-bf", "--binary-file"}, "FNAME",
         "binary file containing the prompt (default: none)",
@@ -2094,14 +2094,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.mmproj.path = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_ASR}));
     add_opt(common_arg(
         {"--mmproj-url"}, "URL",
         "URL to a multimodal projector file for LLaVA. see examples/llava/README.md",
         [](common_params & params, const std::string & value) {
             params.mmproj.url = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_ASR}));
     add_opt(common_arg(
         {"--image"}, "FILE",
         "path to an image file. use with multimodal models. Specify multiple times for batching",

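Note: these hunks introduce no new flags; they only extend the set_examples() lists so that the existing --in-file option (the one that fills params.in_files) and the --mmproj / --mmproj-url options are also accepted when parsing with LLAMA_EXAMPLE_ASR.
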
common/common.h

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_ASR,
 
     LLAMA_EXAMPLE_COUNT,
 };

convert_hf_to_gguf.py

Lines changed: 3 additions & 0 deletions
@@ -5297,6 +5297,9 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_layer_norm_eps(1e-5) # default from whisper
         self.gguf_writer.add_block_count(audio_config["encoder_layers"])
         self.gguf_writer.add_n_mel_bins(audio_config["num_mel_bins"])
+        self.gguf_writer.add_mm_stack_factor(self.hparams["stack_factor"])
+        self.gguf_writer.add_mm_embd_dim(self.hparams["hidden_size"])
+        self.gguf_writer.add_mm_output_dim(2048) # TODO: read from text_model_id
         # We only have encoder, so we will always use non-causal attention
         self.gguf_writer.add_causal_attention(False)
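
The add_mm_stack_factor / add_mm_embd_dim / add_mm_output_dim writer helpers are not defined in this hunk; they are presumably added to gguf-py elsewhere in this commit. As a rough sketch only (the method body and key name below are assumptions, not code from this change), such a GGUFWriter helper typically just writes a single KV entry:

    # assumed sketch of a gguf-py GGUFWriter helper; the key string is a placeholder
    def add_mm_stack_factor(self, value: int) -> None:
        self.add_uint32("mm.stack_factor", value)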

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ else()
     add_subdirectory(tokenize)
     add_subdirectory(tts)
     add_subdirectory(gen-docs)
+    add_subdirectory(asr)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading
         add_subdirectory(convert-llama2c-to-ggml)

examples/asr/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
set(TARGET llama-asr-ultravox)
add_executable(${TARGET} asr-ultravox.cpp whisper-preprocessor.h dr_wav.h)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
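
With the new subdirectory registered in examples/CMakeLists.txt above, this target should build through the usual llama.cpp CMake workflow; assuming default options, something like cmake -B build followed by cmake --build build --target llama-asr-ultravox would produce the binary.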

examples/asr/asr-ultravox.cpp

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"

#include "ggml.h"
#include "ggml-cpp.h"
#include "gguf.h"

#include "whisper-preprocessor.h"

static void show_additional_info(int /*argc*/, char ** argv) {
    LOG(
        "TODO\n\n"
        "Usage: %s [options] -m <model> --mmproj <mmproj> --in-file <image> -p <prompt>\n\n",
        argv[0]
    );
}

struct hook_data {
    std::vector<float> embd;
    int n_token_output;
};

// hook to retrieve the embeddings (because we cannot use arbitrary output tensor **shape**)
static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
    hook_data * data = (hook_data *) user_data;

    if (t && strcmp(t->name, "result_embd_pooled") == 0) {
        if (ask) return true;
        data->embd.resize(ggml_nelements(t));
        data->n_token_output = t->ne[0];
        ggml_backend_tensor_get(t, data->embd.data(), 0, ggml_nbytes(t));
        printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
        return true;
    }

    return false;
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;
    params.prompt = "Transcribe the audio";
    params.sampling.temp = 0.2; // lower temp by default for better quality

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_ASR, show_additional_info)) {
        return 1;
    }

    common_init();

    if (params.mmproj.path.empty()) {
        show_additional_info(argc, argv);
        return 1;
    }

    common_init_result ll_result = common_init_from_params(params);
    llama_model * ll_model = ll_result.model.get();
    llama_context * ll_ctx = ll_result.context.get();

    if (!ll_model || !ll_ctx) {
        LOG_ERR("Failed to initialize LLM\n");
        return 1;
    }

    common_params params_enc(params); // copy
    params_enc.model.path = params.mmproj.path;
    params_enc.warmup = false;
    params_enc.n_ubatch = 1500;
    params_enc.n_batch = 1500;
    params_enc.embedding = true;

    hook_data hook_data;
    params_enc.cb_eval = ggml_callback;
    params_enc.cb_eval_user_data = &hook_data;

    common_init_result enc_result = common_init_from_params(params_enc);
    llama_model * enc_model = enc_result.model.get();
    llama_context * enc_ctx = enc_result.context.get();

    if (!enc_model || !enc_ctx) {
        LOG_ERR("Failed to initialize audio encoder model\n");
        return 1;
    }

    // load mel_filters
    whisper_preprocessor::whisper_filters mel_filters;
    {
        ggml_context * meta = nullptr;
        gguf_init_params params = {
            /*.no_alloc = */ true,
            /*.ctx      = */ &meta,
        };
        gguf_context_ptr ctx_gguf(gguf_init_from_file(params_enc.model.path.c_str(), params));

        // read size
        auto mel_filters_tensor = ggml_get_tensor(meta, "whisper.mel_filters");
        mel_filters.n_mel = mel_filters_tensor->ne[1];
        mel_filters.n_fft = mel_filters_tensor->ne[0];
        mel_filters.data.resize(mel_filters.n_mel * mel_filters.n_fft);

        // read data
        auto idx = gguf_find_tensor(ctx_gguf.get(), "whisper.mel_filters");
        auto offset = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), idx);
        auto size = gguf_get_tensor_size(ctx_gguf.get(), idx);
        auto fin = std::ifstream(params_enc.model.path, std::ios::binary);
        fin.seekg(offset, std::ios::beg);
        fin.read(reinterpret_cast<char *>(mel_filters.data.data()), size);
        fin.close();

        printf("mel_filters: n_mel = %d, n_fft = %d\n", mel_filters.n_mel, mel_filters.n_fft);
        ggml_free(meta);
    }

    // read wav file
    std::vector<float> pcmf32;               // mono-channel F32 PCM
    std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
    auto fname_inp = params.in_files[0]; // TODO: support multiple files
    if (!wav_utils::read_wav(fname_inp, pcmf32, pcmf32s, false)) {
        fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
        return 1;
    }

    // mel spectrogram
    whisper_preprocessor::whisper_mel mel;
    whisper_preprocessor::log_mel_spectrogram(
        pcmf32.data(),
        pcmf32.size(),
        WHISPER_SAMPLE_RATE,
        WHISPER_N_FFT,
        WHISPER_HOP_LENGTH,
        mel_filters.n_mel,
        4, // threads
        mel_filters,
        false,
        mel);
    printf("mel.n_len: %d\n", mel.n_len);
    printf("mel.n_mel: %d\n", mel.n_mel);
    printf("mel.size: %zu\n", mel.data.size());

    // encode audio
    {
        int n_ctx = llama_model_n_ctx_train(enc_model);
        int n_embd = llama_model_n_embd(enc_model);
        std::vector<float> embd(n_ctx * n_embd, 0.0f);
        // set the input
        {
            int mel_offset = 0;

            const int i0 = std::min(mel_offset, mel.n_len);
            const int i1 = std::min(mel_offset + 2*n_ctx, mel.n_len);

            for (int j = 0; j < mel.n_mel; ++j) {
                for (int i = i0; i < i1; ++i) {
                    embd[j*2*n_ctx + (i - i0)] = mel.data[j*mel.n_len + i];
                }
            }
        }

        // set the input
        llama_batch batch_embd = llama_batch_init(n_ctx, n_embd, 1);
        batch_embd.n_tokens = n_ctx;
        for (int i = 0; i < batch_embd.n_tokens; i++) {
            batch_embd.pos[i] = 0; // dummy, unused
            batch_embd.seq_id[i][0] = 0;
            batch_embd.n_seq_id[i] = 1;
            batch_embd.logits[i] = false;
        }
        std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));

        if (llama_decode(enc_ctx, batch_embd) != 0) {
            LOG_ERR("%s: audio llama_decode() failed\n", __func__);
            return 1;
        }

        // float * embd_out = hook_data.embd.data();
        // print out the first 10 embeddings
        // for (int i = 0; i < 10; i++) {
        //     printf("embd_out[%d] = %f\n", i, embd_out[i]);
        // }

        llama_batch_free(batch_embd);
    }

    // generate text
    {
        llama_batch batch_token = llama_batch_init(llama_n_ctx(ll_ctx), 0, 1);
        llama_batch batch_embd = llama_batch_init(hook_data.n_token_output, llama_model_n_embd(ll_model), 1);
        int n_past = 0;

        auto eval_text = [&](std::string text, bool add_bos = false) {
            llama_tokens prompt_tokens = common_tokenize(ll_ctx, text, add_bos, true);
            common_batch_clear(batch_token);
            for (auto & token : prompt_tokens) {
                common_batch_add(batch_token, token, n_past++, {0}, false);
            }
            if (!add_bos) {
                // TODO: a bit hacky here
                batch_token.logits[batch_token.n_tokens - 1] = true;
            }
            if (llama_decode(ll_ctx, batch_token) != 0) {
                LOG_ERR("%s: audio llama_decode() failed\n", __func__);
                exit(1);
            }
        };

        auto eval_embd = [&](std::vector<float> & embd, int n_tokens) {
            batch_embd.n_tokens = n_tokens;
            for (int i = 0; i < n_tokens; i++) {
                batch_embd.pos[i] = n_past++;
                batch_embd.seq_id[i][0] = 0;
                batch_embd.n_seq_id[i] = 1;
                batch_embd.logits[i] = false;
            }
            std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
            if (llama_decode(ll_ctx, batch_embd) != 0) {
                LOG_ERR("%s: audio llama_decode() failed\n", __func__);
                exit(1);
            }
        };

        eval_text("<|start_header_id|>user<|end_header_id|>\n\n" + params.prompt + "\n\n", true);
        eval_embd(hook_data.embd, hook_data.n_token_output);
        eval_text("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");

        struct common_sampler * smpl = common_sampler_init(ll_model, params.sampling);

        int n_predict = 50;
        for (int i = 0; i < n_predict; i++) {
            llama_token token_id = common_sampler_sample(smpl, ll_ctx, -1);
            common_sampler_accept(smpl, token_id, true);

            if (llama_vocab_is_eog(llama_model_get_vocab(ll_model), token_id)) {
                printf("\n");
                break; // end of generation
            }

            printf("%s", common_token_to_piece(ll_ctx, token_id).c_str());
            fflush(stdout);

            // eval the token
            common_batch_clear(batch_token);
            common_batch_add(batch_token, token_id, n_past++, {0}, true);
            if (llama_decode(ll_ctx, batch_token)) {
                LOG_ERR("failed to decode token\n");
                return 1;
            }
        }

        common_sampler_free(smpl);
        llama_batch_free(batch_token);
        llama_batch_free(batch_embd);
    }
}
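
Taken together, the example loads the whisper.mel_filters tensor from the mmproj GGUF, computes a log-mel spectrogram of the input WAV, runs the audio encoder as an embeddings model while capturing the result_embd_pooled tensor via the cb_eval hook, then feeds those embeddings between the hard-coded chat-template text segments of the text model and samples up to n_predict = 50 tokens. A hypothetical invocation (file names are placeholders, following the usage string above):

    llama-asr-ultravox -m <text-model.gguf> --mmproj <ultravox-encoder.gguf> --in-file <audio.wav> -p "Transcribe the audio"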
