Skip to content

Commit f897eb7

Browse files
authored
whisper : support no_speech_thold (#2625)
* Implement no_speech_thold — functionality is on par with OpenAI's Whisper
* Addressed review comments
1 parent 2f2841b commit f897eb7

File tree

2 files changed

+61
-33
lines changed

2 files changed

+61
-33
lines changed

include/whisper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ extern "C" {
534534
float temperature_inc;
535535
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
536536
float logprob_thold;
537-
float no_speech_thold; // TODO: not implemented
537+
float no_speech_thold;
538538

539539
struct {
540540
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264

src/whisper.cpp

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,7 @@ struct whisper_state {
867867
whisper_token tid_last;
868868

869869
std::vector<float> energy; // PCM signal energy
870+
float no_speech_prob = 0.0f;
870871

871872
// [EXPERIMENTAL] Token-level timestamps with DTW
872873
whisper_aheads_masks aheads_masks;
@@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
48254826
"♪♪♪","", "", "", "", "", "", ""
48264827
};
48274828

4829+
// Compute log_softmax over the first n_logits entries of `logits`, writing the
// result into `logprobs` (which must already have at least n_logits elements).
// Entries that were masked out (logit == -INFINITY) stay at -INFINITY so that
// downstream sampling can never select them.
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
        const int                  n_logits,
              std::vector<float> & logprobs) {
    // Take the max over the same n_logits prefix that is normalized below.
    // `logits` may be a larger buffer (e.g. holding several token positions),
    // so scanning the whole vector could pick up an unrelated value.
    const float logit_max = *std::max_element(logits.begin(), logits.begin() + n_logits);

    // Stable log-sum-exp: subtract the max before exponentiating.
    float logsumexp = 0.0f;
    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logsumexp += expf(logits[i] - logit_max);
        }
    }
    logsumexp = logf(logsumexp) + logit_max;

    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logprobs[i] = logits[i] - logsumexp;
        } else {
            logprobs[i] = -INFINITY;
        }
    }
}
4850+
4851+
// Convert precomputed logprobs into probabilities for the first n_logits
// tokens. Tokens that were suppressed at the logit level (logit == -INFINITY)
// are assigned probability 0 instead of exponentiating their logprob.
static void whisper_compute_probs(
        const std::vector<float> & logits,
        const int                  n_logits,
        const std::vector<float> & logprobs,
              std::vector<float> & probs) {
    for (int k = 0; k < n_logits; ++k) {
        probs[k] = (logits[k] == -INFINITY) ? 0.0f : expf(logprobs[k]);
    }
}
4864+
48284865
// process the logits for the selected decoder
48294866
// - applies logit filters
48304867
// - computes logprobs and probs
@@ -4886,7 +4923,7 @@ static void whisper_process_logits(
48864923

48874924
// suppress sot and nosp tokens
48884925
logits[vocab.token_sot] = -INFINITY;
4889-
logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
4926+
logits[vocab.token_nosp] = -INFINITY;
48904927

48914928
// [TDRZ] when tinydiarize is disabled, suppress solm token
48924929
if (params.tdrz_enable == false) {
@@ -4985,24 +5022,7 @@ static void whisper_process_logits(
49855022
}
49865023

49875024
// populate the logprobs array (log_softmax)
4988-
{
4989-
const float logit_max = *std::max_element(logits.begin(), logits.end());
4990-
float logsumexp = 0.0f;
4991-
for (int i = 0; i < n_logits; ++i) {
4992-
if (logits[i] > -INFINITY) {
4993-
logsumexp += expf(logits[i] - logit_max);
4994-
}
4995-
}
4996-
logsumexp = logf(logsumexp) + logit_max;
4997-
4998-
for (int i = 0; i < n_logits; ++i) {
4999-
if (logits[i] > -INFINITY) {
5000-
logprobs[i] = logits[i] - logsumexp;
5001-
} else {
5002-
logprobs[i] = -INFINITY;
5003-
}
5004-
}
5005-
}
5025+
whisper_compute_logprobs(logits, n_logits, logprobs);
50065026

50075027
// if sum of probability over timestamps is above any other token, sample timestamp
50085028
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -5060,15 +5080,7 @@ static void whisper_process_logits(
50605080
}
50615081

50625082
// compute probs
5063-
{
5064-
for (int i = 0; i < n_logits; ++i) {
5065-
if (logits[i] == -INFINITY) {
5066-
probs[i] = 0.0f;
5067-
} else {
5068-
probs[i] = expf(logprobs[i]);
5069-
}
5070-
}
5071-
}
5083+
whisper_compute_probs(logits, n_logits, logprobs, probs);
50725084

50735085
#if 0
50745086
// print first 100 logits - token string : logit
@@ -5647,6 +5659,18 @@ int whisper_full_with_state(
56475659
return -8;
56485660
}
56495661

5662+
// Calculate no_speech probability after first decode.
5663+
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
5664+
{
5665+
const int n_logits = ctx->vocab.id_to_token.size();
5666+
std::vector<float> logprobs(n_logits);
5667+
std::vector<float> probs(n_logits);
5668+
5669+
whisper_compute_logprobs(state->logits, n_logits, logprobs);
5670+
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
5671+
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
5672+
}
5673+
56505674
{
56515675
const int64_t t_start_sample_us = ggml_time_us();
56525676

@@ -6038,8 +6062,9 @@ int whisper_full_with_state(
60386062
if (it != (int) temperatures.size() - 1) {
60396063
const auto & decoder = state->decoders[best_decoder_id];
60406064

6041-
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
6042-
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
6065+
if (decoder.failed ||
6066+
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
6067+
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
60436068
success = false;
60446069
state->n_fail_p++;
60456070
}
@@ -6068,6 +6093,9 @@ int whisper_full_with_state(
60686093
// [EXPERIMENTAL] Token-level timestamps with DTW
60696094
const auto n_segments_before = state->result_all.size();
60706095

6096+
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
6097+
best_decoder.sequence.avg_logprobs < params.logprob_thold);
6098+
60716099
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
60726100

60736101
// update prompt_past
@@ -6076,11 +6104,11 @@ int whisper_full_with_state(
60766104
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
60776105
}
60786106

6079-
for (int i = 0; i < result_len; ++i) {
6107+
for (int i = 0; i < result_len && !is_no_speech; ++i) {
60806108
prompt_past.push_back(tokens_cur[i].id);
60816109
}
60826110

6083-
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
6111+
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
60846112
int i0 = 0;
60856113
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
60866114

0 commit comments

Comments
 (0)