@@ -60,7 +60,13 @@ struct whisper_params {
6060 int32_t max_tokens = 32 ;
6161 int32_t audio_ctx = 0 ;
6262 int32_t n_gpu_layers = 999 ;
63-
63+ int32_t seed = 0 ;
64+ int32_t top_k = 5 ;
65+ int32_t min_keep = 1 ;
66+ float top_p = 0 .80f ;
67+ float min_p = 0 .01f ;
68+ float temp = 0 .30f ;
69+
6470 float vad_thold = 0 .6f ;
6571 float freq_thold = 100 .0f ;
6672
@@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
102108 else if (arg == " -mt" || arg == " --max-tokens" ) { params.max_tokens = std::stoi (argv[++i]); }
103109 else if (arg == " -ac" || arg == " --audio-ctx" ) { params.audio_ctx = std::stoi (argv[++i]); }
104110 else if (arg == " -ngl" || arg == " --n-gpu-layers" ) { params.n_gpu_layers = std::stoi (argv[++i]); }
111+ else if (arg == " --seed" ) { params.seed = std::stoi (argv[++i]); }
112+ else if (arg == " --top-k" ) { params.top_k = std::stoi (argv[++i]); }
113+ else if (arg == " --min-keep" ) { params.min_keep = std::stoul (argv[++i]);}
114+ else if (arg == " --top-p" ) { params.top_p = std::stof (argv[++i]); }
115+ else if (arg == " --min-p" ) { params.min_p = std::stof (argv[++i]); }
116+ else if (arg == " --temp" ) { params.temp = std::stof (argv[++i]); }
105117 else if (arg == " -vth" || arg == " --vad-thold" ) { params.vad_thold = std::stof (argv[++i]); }
106118 else if (arg == " -fth" || arg == " --freq-thold" ) { params.freq_thold = std::stof (argv[++i]); }
107119 else if (arg == " -tr" || arg == " --translate" ) { params.translate = true ; }
@@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
150162 fprintf (stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n " , params.max_tokens );
151163 fprintf (stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n " , params.audio_ctx );
152164 fprintf (stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n " , params.n_gpu_layers );
165+ fprintf (stderr, " --seed N [%-7d] seed sampling\n " , params.seed );
166+ fprintf (stderr, " --top-k N [%-7d] top-k sampling (0 = disabled)\n " , params.top_k );
167+ fprintf (stderr, " --min-keep N [%-7d] minimum number of tokens to keep\n " , params.min_keep );
168+ fprintf (stderr, " --top-p N [%-7.2f] top-p sampling\n " , params.top_p );
169+ fprintf (stderr, " --min-p N [%-7.2f] min-p sampling\n " , params.min_p );
170+ fprintf (stderr, " --temp N [%-7.2f] temperature\n " , params.temp );
153171 fprintf (stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n " , params.vad_thold );
154172 fprintf (stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n " , params.freq_thold );
155173 fprintf (stderr, " -tr, --translate [%-7s] translate from source language to english\n " , params.translate ? " true" : " false" );
@@ -409,21 +427,16 @@ int main(int argc, char ** argv) {
409427 llama_batch batch = llama_batch_init (llama_n_ctx (ctx_llama), 0 , 1 );
410428
411429 // init sampler
412- const float top_k = 5 ;
413- const float top_p = 0 .80f ;
414- const float temp = 0 .30f ;
415-
416- const int seed = 0 ;
417-
418430 auto sparams = llama_sampler_chain_default_params ();
419431
420432 llama_sampler * smpl = llama_sampler_chain_init (sparams);
421433
422- if (temp > 0 .0f ) {
423- llama_sampler_chain_add (smpl, llama_sampler_init_top_k (top_k));
424- llama_sampler_chain_add (smpl, llama_sampler_init_top_p (top_p, 1 ));
425- llama_sampler_chain_add (smpl, llama_sampler_init_temp (temp));
426- llama_sampler_chain_add (smpl, llama_sampler_init_dist (seed));
434+ if (params.temp > 0 .0f ) {
435+ llama_sampler_chain_add (smpl, llama_sampler_init_top_k (params.top_k ));
436+ llama_sampler_chain_add (smpl, llama_sampler_init_top_p (params.top_p , params.min_keep ));
437+ llama_sampler_chain_add (smpl, llama_sampler_init_temp (params.temp ));
438+ llama_sampler_chain_add (smpl, llama_sampler_init_dist (params.seed ));
439+ llama_sampler_chain_add (smpl, llama_sampler_init_min_p (params.min_p , params.min_keep ));
427440 } else {
428441 llama_sampler_chain_add (smpl, llama_sampler_init_greedy ());
429442 }
@@ -615,7 +628,7 @@ int main(int argc, char ** argv) {
615628 }
616629
617630 // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
618- text_heard = std::regex_replace (text_heard, std::regex (" [^a-zA-Z0-9 \\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
631+ text_heard = std::regex_replace (text_heard, std::regex (" [^a-zA-Z0-9åäöÅÄÖ \\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
619632
620633 // take first line
621634 text_heard = text_heard.substr (0 , text_heard.find_first_of (' \n ' ));
0 commit comments