Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 138 additions & 5 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7695,7 +7695,7 @@ int whisper_full_with_state(
whisper_exp_compute_token_level_timestamps(
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

if (params.max_len > 0) {
if (params.max_len > 0 && !ctx->params.dtw_token_timestamps) {
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
}
}
Expand All @@ -7708,15 +7708,14 @@ int whisper_full_with_state(
// FIXME: will timestamp offsets be correct?
// [EXPERIMENTAL] Token-level timestamps with DTW
{
const int n_segments = state->result_all.size() - n_segments_before;
int n_segments = state->result_all.size() - n_segments_before;
if (ctx->params.dtw_token_timestamps && n_segments) {
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
whisper_exp_compute_token_level_timestamps_dtw(
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);

if (params.new_segment_callback) {
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
}
params.new_segment_callback(ctx, state, n_segments, params.new_segment_callback_user_data);
}
}
}
Expand Down Expand Up @@ -8949,6 +8948,140 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
}
}

// DTW timestamp refinement constants (in centiseconds, 1 cs = 10ms)
// These values are tuned for natural speech at ~150 WPM
static const int64_t DTW_MIN_TOKEN_DUR    = 5;  // 50ms absolute minimum token duration
static const int     DTW_ONSET_VOWEL      = 15; // 150ms onset shift when the first letter is in DTW_ONSET_PHONEMES (vowels AND plosives)
static const int     DTW_ONSET_CONSONANT  = 8;  // 80ms onset shift for any other a-z first letter
static const int     DTW_DUR_PER_CHAR     = 2;  // 20ms per character for min duration
static const int     DTW_MAX_DUR_PER_CHAR = 15; // 150ms per character for max duration
static const int64_t DTW_MAX_DUR_BASE     = 10; // 100ms base max duration

// vowels + plosives benefit from earlier onset to match perceived speech start
// note: this is a set of lowercase ASCII letters that is matched against the first
//       character of the token text - it is a spelling heuristic, not a phoneme lookup
// the pointer itself is const as well, so the table cannot be re-seated at runtime
static const char * const DTW_ONSET_PHONEMES = "aeiouywbcdgkpqt";

// helper: get the lower bound that the start of token (seg_idx, tok_idx) must not cross
// returns the t_dtw of the closest preceding text token (id < eot):
//   1. searched backward inside the same segment
//   2. otherwise searched backward inside the immediately preceding segment
//   3. otherwise 0
// note: despite the name, the value is the previous token's t_dtw, not a separately stored end time
auto get_prev_end = [&](size_t seg_idx, int tok_idx) -> int64_t {
    const auto eot = whisper_token_eot(ctx);

    const auto & seg = state->result_all[seg_idx];
    for (int t2 = tok_idx - 1; t2 >= 0; --t2) {
        if (seg.tokens[t2].id < eot) {
            return seg.tokens[t2].t_dtw;
        }
    }

    // the last token of the previous segment may be a special token (id >= eot), which every
    // other lookup here skips - so scan backward for the last text token instead of blindly
    // taking tokens.back() (mirrors the forward scan done by get_next_start)
    if (seg_idx > 0) {
        const auto & prev_tokens = state->result_all[seg_idx - 1].tokens;
        for (auto it = prev_tokens.rbegin(); it != prev_tokens.rend(); ++it) {
            if (it->id < eot) {
                return it->t_dtw;
            }
        }
    }

    return 0;
};

// helper: t_dtw of the first text token (id < eot) that follows (seg_idx, tok_idx)
// the rest of the current segment is searched first, then the segment right after it;
// when neither contains a text token the caller-supplied fallback is returned
auto get_next_start = [&](size_t seg_idx, int tok_idx, int64_t fallback) -> int64_t {
    const auto & cur_tokens = state->result_all[seg_idx].tokens;
    const int n_cur = cur_tokens.size();

    int j = tok_idx + 1;
    while (j < n_cur) {
        const auto & cand = cur_tokens[j];
        if (cand.id < whisper_token_eot(ctx)) {
            return cand.t_dtw;
        }
        ++j;
    }

    const size_t next_idx = seg_idx + 1;
    if (next_idx >= state->result_all.size()) {
        return fallback;
    }

    const auto & next_tokens = state->result_all[next_idx].tokens;
    for (size_t k = 0; k < next_tokens.size(); ++k) {
        if (next_tokens[k].id < whisper_token_eot(ctx)) {
            return next_tokens[k].t_dtw;
        }
    }

    return fallback;
};

// helper: token text with one leading space stripped, plus its length
// - the length is the strlen() byte count, so multi-byte characters count more than once
// - the length is clamped to >= 1, even when the text is missing or empty after stripping
// - the returned pointer is null when whisper_token_to_str() returns null
auto get_text_len = [&](whisper_token id) -> std::pair<const char*, int> {
    const char * str = whisper_token_to_str(ctx, id);
    if (!str) {
        return {str, 1};
    }

    int n_bytes = (int)strlen(str);
    if (n_bytes > 0 && str[0] == ' ') {
        ++str;
        --n_bytes;
    }

    return {str, n_bytes < 1 ? 1 : n_bytes};
};

// pass 1: onset shift + min duration adjustment
// for every text token (id < eot) of the processed segments, t_dtw is moved earlier:
//   - by a fixed onset shift chosen from the first letter of the token text
//   - by whatever is missing to reach the adaptive minimum duration
// both moves are applied only if the new start stays behind the previous token bound
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
    auto & segment = state->result_all[i];
    const int n_tokens = segment.tokens.size();

    for (int t = 0; t < n_tokens; ++t) {
        auto & tok = segment.tokens[t];
        if (tok.id >= whisper_token_eot(ctx)) continue;

        auto text_pair = get_text_len(tok.id);
        const char * text = text_pair.first;
        int len = text_pair.second;

        // the bound only depends on tokens before this one, so a single lookup serves both steps
        const int64_t prev_end = get_prev_end(i, t);

        // onset shift: move start earlier for vowels/plosives
        // - the byte is read as unsigned char: tolower() is undefined for negative values,
        //   which is what plain char yields for UTF-8 lead bytes on signed-char platforms
        // - an empty text (len is clamped to 1, so it cannot be used to detect this) is skipped:
        //   strchr() would match the string terminator and wrongly select the largest shift
        const unsigned char first = text ? (unsigned char) text[0] : 0;
        if (first != '\0') {
            const int c = tolower(first);
            int shift = 0;
            if (strchr(DTW_ONSET_PHONEMES, c)) {
                shift = DTW_ONSET_VOWEL;
            } else if (c >= 'a' && c <= 'z') {
                shift = DTW_ONSET_CONSONANT;
            }
            if (shift > 0 && tok.t_dtw - shift > prev_end + 1) {
                tok.t_dtw -= shift;
            }
        }

        // min duration: extend backward if too short
        // a negative duration (next token starts before this one) is left untouched
        int64_t next_start = get_next_start(i, t, segment.t1);
        int64_t duration = next_start - tok.t_dtw;
        int64_t len_based_min = (int64_t)(len * DTW_DUR_PER_CHAR);
        int64_t adaptive_min = (DTW_MIN_TOKEN_DUR > len_based_min) ? DTW_MIN_TOKEN_DUR : len_based_min;

        if (duration >= 0 && duration < adaptive_min) {
            int64_t new_start = tok.t_dtw - (adaptive_min - duration);
            if (new_start > prev_end + 2) {
                tok.t_dtw = new_start;
            }
        }
    }
}

// pass 2: propagate t_dtw to t0/t1 with max duration cap
// per text token (id < eot):
//   t0 = t_dtw
//   t1 = start of the next text token, but at least t0 + DTW_MIN_TOKEN_DUR
//        and at most t0 + max(DTW_MAX_DUR_BASE, len * DTW_MAX_DUR_PER_CHAR)
// the segment bounds are then synced to the first/last text token of the segment
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
    auto & segment = state->result_all[i];
    const int n_tokens = segment.tokens.size();

    // negative means "not seen yet"
    int64_t seg_first_t0 = -1;
    int64_t seg_last_t1  = -1;

    for (int t = 0; t < n_tokens; ++t) {
        auto & cur = segment.tokens[t];
        if (cur.id >= whisper_token_eot(ctx)) {
            continue;
        }

        cur.t0 = cur.t_dtw;

        const int n_chars = get_text_len(cur.id).second;

        // segment.t1 is only rewritten after this loop, so the fallback is still the original value
        const int64_t following = get_next_start(i, t, segment.t1);
        const int64_t cap_len   = (int64_t)(n_chars * DTW_MAX_DUR_PER_CHAR);
        const int64_t cap       = std::max<int64_t>(DTW_MAX_DUR_BASE, cap_len);

        const int64_t floor_t1  = cur.t0 + DTW_MIN_TOKEN_DUR;
        const int64_t wanted_t1 = std::max<int64_t>(following, floor_t1);

        if (wanted_t1 - cur.t0 > cap) {
            cur.t1 = cur.t0 + cap;
        } else {
            cur.t1 = wanted_t1;
        }

        // track the token bounds for the segment sync below
        if (seg_first_t0 < 0) {
            seg_first_t0 = cur.t0;
        }
        seg_last_t1 = cur.t1;
    }

    // sync segment boundaries with token bounds
    if (seg_first_t0 >= 0) {
        segment.t0 = seg_first_t0;
    }
    if (seg_last_t1 >= 0) {
        segment.t1 = seg_last_t1;
    }
}

// Print DTW timestamps
/*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
auto & segment = state->result_all[i];
Expand Down