Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ extern "C" {
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);

WHISPER_API int whisper_get_lang_id_from_state(struct whisper_state * state);
WHISPER_API float whisper_get_lang_prob_from_state(struct whisper_state * state);

// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
Expand Down Expand Up @@ -465,6 +468,9 @@ extern "C" {
// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);

// Detected language callback
typedef void (*whisper_detected_language_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
Expand Down Expand Up @@ -562,6 +568,10 @@ extern "C" {
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;

// called on detected language
whisper_detected_language_callback detected_language_callback;
void * detected_language_callback_user_data;

// called on each progress update
whisper_progress_callback progress_callback;
void * progress_callback_user_data;
Expand Down
20 changes: 19 additions & 1 deletion src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ struct whisper_state {
std::vector<whisper_token> prompt_past1; // dynamic context from decoded output

int lang_id = 0; // english by default
float lang_prob = 0.0f; // probability of the detected language

std::string path_model; // populated by whisper_init_from_file_with_params()

Expand Down Expand Up @@ -4198,6 +4199,14 @@ float * whisper_get_logits_from_state(struct whisper_state * state) {
return state->logits.data();
}

int whisper_get_lang_id_from_state(struct whisper_state * state) {
return state->lang_id;
}

float whisper_get_lang_prob_from_state(struct whisper_state * state) {
return state->lang_prob;
}

const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
return ctx->vocab.id_to_token.at(token).c_str();
}
Expand Down Expand Up @@ -5968,6 +5977,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,

/*.detected_language_callback =*/ nullptr,
/*.detected_language_callback_user_data =*/ nullptr,

/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,

Expand Down Expand Up @@ -6818,9 +6830,15 @@ int whisper_full_with_state(
return -3;
}
state->lang_id = lang_id;
state->lang_prob = probs[lang_id];
params.language = whisper_lang_str(lang_id);

WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[lang_id]);

if (params.detected_language_callback) {
params.detected_language_callback(ctx, state, params.detected_language_callback_user_data);
}

if (params.detect_language) {
return 0;
}
Expand Down