Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions examples/talk-llama/talk-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_gpu_layers = 999;

int32_t seed = 0;
int32_t top_k = 5;
size_t min_keep = 1;
float top_p = 0.80f;
float min_p = 0.01f;
float temp = 0.30f;

float vad_thold = 0.6f;
float freq_thold = 100.0f;

Expand Down Expand Up @@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
else if (arg == "--seed") { params.seed = std::stoi(argv[++i]); }
else if (arg == "--top-k") { params.top_k = std::stoi(argv[++i]); }
else if (arg == "--min-keep") { params.min_keep = std::stoul(argv[++i]);}
else if (arg == "--top-p") { params.top_p = std::stof(argv[++i]); }
else if (arg == "--min-p") { params.min_p = std::stof(argv[++i]); }
else if (arg == "--temp") { params.temp = std::stof(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
Expand Down Expand Up @@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " --seed N [%-7d] seed sampling\n", params.seed);
fprintf(stderr, " --top-k N [%-7d] top-k sampling (0 = disabled)\n", params.top_k);
fprintf(stderr, " --min-keep N [%-7d] minimum number of tokens to keep\n", params.min_keep);
fprintf(stderr, " --top-p N [%-7.2f] top-p sampling\n", params.top_p);
fprintf(stderr, " --min-p N [%-7.2f] min-p sampling\n", params.min_p);
fprintf(stderr, " --temp N [%-7.2f] temperature\n", params.temp);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
Expand Down Expand Up @@ -409,21 +427,16 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);

// init sampler
const float top_k = 5;
const float top_p = 0.80f;
const float temp = 0.30f;

const int seed = 0;

auto sparams = llama_sampler_chain_default_params();

llama_sampler * smpl = llama_sampler_chain_init(sparams);

if (temp > 0.0f) {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
if (params.temp > 0.0f) {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.top_p, params.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.seed));
llama_sampler_chain_add(smpl, llama_sampler_init_min_p (params.min_p, params.min_keep));
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
}
Expand Down Expand Up @@ -615,7 +628,7 @@ int main(int argc, char ** argv) {
}

// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9åäöÅÄÖ\\.,\\?!\\s\\:\\'\\-]"), "");

// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
Expand Down