Skip to content

Commit 3d78d2c

Browse files
committed
vad : add whisper_vad_init_with_params_no_state
This commit extracts the logic of loading the VAD model from a file into a separate function so that it is more inline with how the whisper model is loaded.
1 parent c065556 commit 3d78d2c

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

include/whisper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,9 @@ extern "C" {
694694
const char * path_model,
695695
const struct whisper_vad_context_params params);
696696

697+
WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params_no_state(struct whisper_model_loader * loader,
698+
struct whisper_vad_context_params params);
699+
697700
struct whisper_vad_speech {
698701
int n_probs;
699702
float * probs;

src/whisper.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4840,47 +4840,55 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
48404840
fin->close();
48414841
};
48424842

4843-
// Read the VAD model TODO(danbev) Extract to separate function
4843+
auto ctx = whisper_vad_init_with_params_no_state(&loader, params);
4844+
if (ctx) {
4845+
ctx->path_model = path_model;
4846+
}
4847+
4848+
return ctx;
4849+
}
4850+
4851+
struct whisper_vad_context * whisper_vad_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_vad_context_params params) {
4852+
// Read the VAD model
48444853
{
48454854
uint32_t magic;
4846-
read_safe(&loader, magic);
4855+
read_safe(loader, magic);
48474856
if (magic != GGML_FILE_MAGIC) {
48484857
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
48494858
return nullptr;
48504859
}
48514860
}
48524861

48534862
whisper_vad_context * vctx = new whisper_vad_context;
4854-
vctx->path_model = path_model;
48554863
vctx->n_threads = params.n_threads;
48564864

48574865
auto & model = vctx->model;
48584866
auto & hparams = model.hparams;
48594867

48604868
// load model context params.
48614869
{
4862-
read_safe(&loader, vctx->n_window);
4863-
read_safe(&loader, vctx->n_context);
4870+
read_safe(loader, vctx->n_window);
4871+
read_safe(loader, vctx->n_context);
48644872
}
48654873

48664874
// load model hyper params (hparams).
48674875
{
4868-
read_safe(&loader, hparams.n_encoder_layers);
4876+
read_safe(loader, hparams.n_encoder_layers);
48694877

48704878
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
48714879
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
48724880
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
48734881

48744882
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4875-
read_safe(&loader, hparams.encoder_in_channels[i]);
4876-
read_safe(&loader, hparams.encoder_out_channels[i]);
4877-
read_safe(&loader, hparams.kernel_sizes[i]);
4883+
read_safe(loader, hparams.encoder_in_channels[i]);
4884+
read_safe(loader, hparams.encoder_out_channels[i]);
4885+
read_safe(loader, hparams.kernel_sizes[i]);
48784886
}
48794887

4880-
read_safe(&loader, hparams.lstm_input_size);
4881-
read_safe(&loader, hparams.lstm_hidden_size);
4882-
read_safe(&loader, hparams.final_conv_in);
4883-
read_safe(&loader, hparams.final_conv_out);
4888+
read_safe(loader, hparams.lstm_input_size);
4889+
read_safe(loader, hparams.lstm_hidden_size);
4890+
read_safe(loader, hparams.final_conv_in);
4891+
read_safe(loader, hparams.final_conv_out);
48844892

48854893
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
48864894
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
@@ -5067,24 +5075,24 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
50675075
int32_t length;
50685076
int32_t ttype;
50695077

5070-
read_safe(&loader, n_dims);
5071-
read_safe(&loader, length);
5072-
read_safe(&loader, ttype);
5078+
read_safe(loader, n_dims);
5079+
read_safe(loader, length);
5080+
read_safe(loader, ttype);
50735081

5074-
if (loader.eof(loader.context)) {
5082+
if (loader->eof(loader->context)) {
50755083
break;
50765084
}
50775085

50785086
int32_t nelements = 1;
50795087
int32_t ne[4] = { 1, 1, 1, 1 };
50805088
for (int i = 0; i < n_dims; ++i) {
5081-
read_safe(&loader, ne[i]);
5089+
read_safe(loader, ne[i]);
50825090
nelements *= ne[i];
50835091
}
50845092

50855093
std::string name;
50865094
std::vector<char> tmp(length); // create a buffer
5087-
loader.read(loader.context, &tmp[0], tmp.size()); // read to buffer
5095+
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
50885096
name.assign(&tmp[0], tmp.size());
50895097

50905098
if (model.tensors.find(name) == model.tensors.end()) {
@@ -5117,13 +5125,13 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
51175125

51185126
if (ggml_backend_buffer_is_host(tensor->buffer)) {
51195127
// for the CPU and Metal backend, we can read directly into the tensor
5120-
loader.read(loader.context, tensor->data, ggml_nbytes(tensor));
5128+
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
51215129
BYTESWAP_TENSOR(tensor);
51225130
} else {
51235131
// read into a temporary buffer first, then copy to device memory
51245132
read_buf.resize(ggml_nbytes(tensor));
51255133

5126-
loader.read(loader.context, read_buf.data(), read_buf.size());
5134+
loader->read(loader->context, read_buf.data(), read_buf.size());
51275135

51285136
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
51295137
}

0 commit comments

Comments
 (0)