@@ -867,6 +867,7 @@ struct whisper_state {
867867 whisper_token tid_last;
868868
869869 std::vector<float > energy; // PCM signal energy
870+ float no_speech_prob = 0 .0f ;
870871
871872 // [EXPERIMENTAL] Token-level timestamps with DTW
872873 whisper_aheads_masks aheads_masks;
@@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
48254826 " ♪♪♪" ," ♩" , " ♪" , " ♫" , " ♬" , " ♭" , " ♮" , " ♯"
48264827};
48274828
// compute log-softmax over the first n_logits entries of `logits`, writing into `logprobs`
// - entries suppressed to -INFINITY keep -INFINITY as their logprob
// - `logprobs` must be pre-sized to hold at least n_logits values
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
        const int    n_logits,
              std::vector<float> & logprobs) {
    // take the max only over the first n_logits entries, consistent with the loops below:
    // the buffer may be larger than n_logits (e.g. logits for several decoded tokens), and a
    // stray larger value would shift logit_max and underflow expf() in the sum below
    const float logit_max = *std::max_element(logits.begin(), logits.begin() + n_logits);
    float logsumexp = 0.0f;
    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logsumexp += expf(logits[i] - logit_max);
        }
    }
    logsumexp = logf(logsumexp) + logit_max;

    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logprobs[i] = logits[i] - logsumexp;
        } else {
            logprobs[i] = -INFINITY;
        }
    }
}
4850+
// convert the first n_logits log-probabilities to probabilities, writing into `probs`
// - tokens whose logit was suppressed to -INFINITY get probability 0
// - `probs` must be pre-sized to hold at least n_logits values
static void whisper_compute_probs(
        const std::vector<float> & logits,
        const int    n_logits,
        const std::vector<float> & logprobs,
              std::vector<float> & probs) {
    for (int i = 0; i < n_logits; ++i) {
        probs[i] = (logits[i] == -INFINITY) ? 0.0f : expf(logprobs[i]);
    }
}
4864+
48284865// process the logits for the selected decoder
48294866// - applies logit filters
48304867// - computes logprobs and probs
@@ -4886,7 +4923,7 @@ static void whisper_process_logits(
48864923
48874924 // suppress sot and nosp tokens
48884925 logits[vocab.token_sot ] = -INFINITY;
4889- logits[vocab.token_nosp ] = -INFINITY; // TODO: ignore this token for now
4926+ logits[vocab.token_nosp ] = -INFINITY;
48904927
48914928 // [TDRZ] when tinydiarize is disabled, suppress solm token
48924929 if (params.tdrz_enable == false ) {
@@ -4985,24 +5022,7 @@ static void whisper_process_logits(
49855022 }
49865023
49875024 // populate the logprobs array (log_softmax)
4988- {
4989- const float logit_max = *std::max_element (logits.begin (), logits.end ());
4990- float logsumexp = 0 .0f ;
4991- for (int i = 0 ; i < n_logits; ++i) {
4992- if (logits[i] > -INFINITY) {
4993- logsumexp += expf (logits[i] - logit_max);
4994- }
4995- }
4996- logsumexp = logf (logsumexp) + logit_max;
4997-
4998- for (int i = 0 ; i < n_logits; ++i) {
4999- if (logits[i] > -INFINITY) {
5000- logprobs[i] = logits[i] - logsumexp;
5001- } else {
5002- logprobs[i] = -INFINITY;
5003- }
5004- }
5005- }
5025+ whisper_compute_logprobs (logits, n_logits, logprobs);
50065026
50075027 // if sum of probability over timestamps is above any other token, sample timestamp
50085028 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -5060,15 +5080,7 @@ static void whisper_process_logits(
50605080 }
50615081
50625082 // compute probs
5063- {
5064- for (int i = 0 ; i < n_logits; ++i) {
5065- if (logits[i] == -INFINITY) {
5066- probs[i] = 0 .0f ;
5067- } else {
5068- probs[i] = expf (logprobs[i]);
5069- }
5070- }
5071- }
5083+ whisper_compute_probs (logits, n_logits, logprobs, probs);
50725084
50735085#if 0
50745086 // print first 100 logits - token string : logit
@@ -5647,6 +5659,18 @@ int whisper_full_with_state(
56475659 return -8 ;
56485660 }
56495661
5662+ // Calculate no_speech probability after first decode.
5663+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
5664+ {
5665+ const int n_logits = ctx->vocab .id_to_token .size ();
5666+ std::vector<float > logprobs (n_logits);
5667+ std::vector<float > probs (n_logits);
5668+
5669+ whisper_compute_logprobs (state->logits , n_logits, logprobs);
5670+ whisper_compute_probs (state->logits , n_logits, logprobs, probs);
5671+ state->no_speech_prob = probs[whisper_token_nosp (ctx)];
5672+ }
5673+
56505674 {
56515675 const int64_t t_start_sample_us = ggml_time_us ();
56525676
@@ -6038,8 +6062,9 @@ int whisper_full_with_state(
60386062 if (it != (int ) temperatures.size () - 1 ) {
60396063 const auto & decoder = state->decoders [best_decoder_id];
60406064
6041- if (decoder.failed || decoder.sequence .avg_logprobs < params.logprob_thold ) {
6042- WHISPER_LOG_DEBUG (" %s: failed due to avg_logprobs %8.5f < %8.5f\n " , __func__, decoder.sequence .avg_logprobs , params.logprob_thold );
6065+ if (decoder.failed ||
6066+ (decoder.sequence .avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold )) {
6067+ WHISPER_LOG_DEBUG (" %s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n " , __func__, decoder.sequence .avg_logprobs , params.logprob_thold , state->no_speech_prob , params.no_speech_thold );
60436068 success = false ;
60446069 state->n_fail_p ++;
60456070 }
@@ -6068,6 +6093,9 @@ int whisper_full_with_state(
60686093 // [EXPERIMENTAL] Token-level timestamps with DTW
60696094 const auto n_segments_before = state->result_all .size ();
60706095
6096+ const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
6097+ best_decoder.sequence .avg_logprobs < params.logprob_thold );
6098+
60716099 // WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
60726100
60736101 // update prompt_past
@@ -6076,11 +6104,11 @@ int whisper_full_with_state(
60766104 prompt_past.insert (prompt_past.end (), prompt.begin () + 1 , prompt.end () - prompt_init.size ());
60776105 }
60786106
6079- for (int i = 0 ; i < result_len; ++i) {
6107+ for (int i = 0 ; i < result_len && !is_no_speech ; ++i) {
60806108 prompt_past.push_back (tokens_cur[i].id );
60816109 }
60826110
6083- if (!tokens_cur.empty () && ctx->model .n_loaded > 0 ) {
6111+ if (!tokens_cur.empty () && ctx->model .n_loaded > 0 && !is_no_speech ) {
60846112 int i0 = 0 ;
60856113 auto t0 = seek + 2 *(tokens_cur.front ().tid - whisper_token_beg (ctx));
60866114
0 commit comments