From f9e225b05c3237866fc0ddc41a3cfe2c03d0071b Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 28 Aug 2025 14:37:08 -0700 Subject: [PATCH 1/5] wip --- Cargo.lock | 34 ++++++++++-- crates/whisper-local/Cargo.toml | 3 +- crates/whisper-local/src/bias.rs | 1 + crates/whisper-local/src/model.rs | 88 +++++++++++++++++++++++++++---- 4 files changed, 112 insertions(+), 14 deletions(-) create mode 100644 crates/whisper-local/src/bias.rs diff --git a/Cargo.lock b/Cargo.lock index 6e71a66b9..2eeacfad8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4541,6 +4541,15 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "fid-rs" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6956a1e60e2d1412b44b4169d44a03dae518f8583d3e10090c912c105e48447" +dependencies = [ + "rayon", +] + [[package]] name = "field-offset" version = "0.3.6" @@ -8607,6 +8616,15 @@ dependencies = [ "url", ] +[[package]] +name = "louds-rs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "936de6c22f08e7135a921f8ada907acd0d88880c4f42b5591f634b9f1dd8e07f" +dependencies = [ + "fid-rs", +] + [[package]] name = "lru" version = "0.12.5" @@ -16455,6 +16473,15 @@ dependencies = [ "petgraph 0.6.5", ] +[[package]] +name = "trie-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f88f4b0a1ebd6c3d16be3e45eb0e8089372ccadd88849b7ca162ba64b5e6f6" +dependencies = [ + "louds-rs", +] + [[package]] name = "trim-in-place" version = "0.1.7" @@ -17509,6 +17536,7 @@ dependencies = [ "thiserror 2.0.16", "tokio", "tracing", + "trie-rs", "whisper", "whisper-rs", ] @@ -17525,8 +17553,7 @@ dependencies = [ [[package]] name = "whisper-rs" version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4d4a5eb3a2a84d3adfb2e84b3ae783ee73428676582b2c91cb66c3091e737256" +source = "git+https://codeberg.org/tazz4843/whisper-rs?rev=3e6d3da#3e6d3da162146e8697fc39c4a81e66a339f14bc6" dependencies = [ "libc", "tracing", @@ -17536,8 +17563,7 @@ dependencies = [ [[package]] name = "whisper-rs-sys" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0927ebac46387e6f3f705cc007c9ea40885f6d924071a918f1ee8b829cf5cd" +source = "git+https://codeberg.org/tazz4843/whisper-rs?rev=3e6d3da#3e6d3da162146e8697fc39c4a81e66a339f14bc6" dependencies = [ "bindgen 0.71.1", "cfg-if", diff --git a/crates/whisper-local/Cargo.toml b/crates/whisper-local/Cargo.toml index 90c18d982..68379f85e 100644 --- a/crates/whisper-local/Cargo.toml +++ b/crates/whisper-local/Cargo.toml @@ -28,7 +28,7 @@ hypr-whisper = { workspace = true } dasp = { workspace = true } kalosm-sound = { workspace = true, default-features = false } rodio = { workspace = true } -whisper-rs = { version = "0.15.0", features = ["raw-api", "tracing_backend"] } +whisper-rs = { git = "https://codeberg.org/tazz4843/whisper-rs", rev = "3e6d3da", features = ["raw-api", "tracing_backend"] } futures-util = { workspace = true } tracing = { workspace = true } @@ -37,6 +37,7 @@ serde = { workspace = true } serde_json = { workspace = true } specta = { workspace = true, features = ["derive"] } thiserror = { workspace = true } +trie-rs = "0.4.2" lazy_static = { workspace = true } regex = { workspace = true } diff --git a/crates/whisper-local/src/bias.rs b/crates/whisper-local/src/bias.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/crates/whisper-local/src/bias.rs @@ -0,0 +1 @@ + diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index 4f8a0c5b6..427962fbd 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -3,6 +3,7 @@ use lazy_static::lazy_static; use regex::Regex; +use trie_rs::map::{Trie, TrieBuilder}; use 
whisper_rs::{ FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState, WhisperTokenId, @@ -52,11 +53,31 @@ impl WhisperBuilder { let state = ctx.create_state()?; let token_beg = ctx.token_beg(); + let bias_trie = { + let custom_vacab = vec!["Hyprnote", "OWhisper"]; + let sequences = custom_vacab + .iter() + .map(|s| ctx.tokenize(s, 99)) + .collect::, _>>()?; + + let mut builder = TrieBuilder::new(); + + for sequence in sequences { + for i in 1..=sequence.len() { + let progress = i as f32 / sequence.len() as f32; + let prefix_bias = 1.0 + progress.powi(2); + builder.push(&sequence[..i], prefix_bias); + } + } + builder.build() + }; + Ok(Whisper { languages: self.languages.unwrap_or_default(), dynamic_prompt: "".to_string(), state, token_beg, + bias_trie, }) } @@ -76,6 +97,7 @@ pub struct Whisper { dynamic_prompt: String, state: WhisperState, token_beg: WhisperTokenId, + bias_trie: Trie, } impl Whisper { @@ -109,7 +131,7 @@ impl Whisper { p.set_initial_prompt(&initial_prompt); unsafe { - Self::suppress_beg(&mut p, &token_beg); + Self::set_logit_filter(&mut p, &token_beg, &self.bias_trie); } p.set_no_timestamps(true); @@ -236,12 +258,26 @@ impl Whisper { .collect() } - unsafe fn suppress_beg(params: &mut FullParams, token_beg: &WhisperTokenId) { + unsafe fn set_logit_filter( + params: &mut FullParams, + token_beg: &WhisperTokenId, + bias_trie: &Trie, + ) { + struct Context { + token_beg: WhisperTokenId, + bias_trie: Trie, + } + + let context = Box::new(Context { + token_beg: *token_beg, + bias_trie: bias_trie.clone(), + }); + unsafe extern "C" fn logits_filter_callback( _ctx: *mut whisper_rs::whisper_rs_sys::whisper_context, _state: *mut whisper_rs::whisper_rs_sys::whisper_state, - _tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, - _n_tokens: std::os::raw::c_int, + tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, + n_tokens: std::os::raw::c_int, logits: *mut f32, user_data: *mut std::os::raw::c_void, ) { @@ 
-249,14 +285,48 @@ impl Whisper { return; } - let token_beg_id = *(user_data as *const WhisperTokenId); - *logits.offset(token_beg_id as isize) = f32::NEG_INFINITY; + let context = &*(user_data as *const Context); + + { + *logits.offset(context.token_beg as isize) = f32::NEG_INFINITY; + } + + { + if !tokens.is_null() && n_tokens > 0 { + let current_tokens: Vec = + std::slice::from_raw_parts(tokens, n_tokens as usize) + .iter() + .map(|t| t.id) + .collect(); + + for start_pos in (n_tokens as usize).saturating_sub(10)..n_tokens as usize { + let suffix = ¤t_tokens[start_pos..]; + + if let Some(_) = context.bias_trie.exact_match(suffix) { + continue; + } + + for (full_sequence, bias_value_ref) in + context.bias_trie.predictive_search(suffix) + { + let bias_value = *bias_value_ref; + let full_sequence: Vec = full_sequence; + + if full_sequence.len() > suffix.len() { + let next_token = full_sequence[suffix.len()]; + let current_logit = *logits.offset(next_token as isize); + *logits.offset(next_token as isize) = + current_logit + bias_value.ln(); + } + } + } + } + } } params.set_filter_logits_callback(Some(logits_filter_callback)); - params.set_filter_logits_callback_user_data( - token_beg as *const WhisperTokenId as *mut std::ffi::c_void, - ); + params + .set_filter_logits_callback_user_data(Box::into_raw(context) as *mut std::ffi::c_void); } } From 1e33df1fde259de7d34e8342996a93c58509fbd0 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 28 Aug 2025 15:04:21 -0700 Subject: [PATCH 2/5] refactor --- crates/whisper-local/src/bias.rs | 65 ++++++++++++++++++++++++++++ crates/whisper-local/src/lib.rs | 3 ++ crates/whisper-local/src/model.rs | 72 ++++++++----------------------- 3 files changed, 86 insertions(+), 54 deletions(-) diff --git a/crates/whisper-local/src/bias.rs b/crates/whisper-local/src/bias.rs index 8b1378917..4a4184987 100644 --- a/crates/whisper-local/src/bias.rs +++ b/crates/whisper-local/src/bias.rs @@ -1 +1,66 @@ +use trie_rs::map::{Trie, TrieBuilder}; 
+use whisper_rs::{WhisperContext, WhisperTokenId}; +#[derive(Clone)] +pub struct BiasTrie { + trie: Trie, +} + +impl BiasTrie { + pub fn new(ctx: &WhisperContext, custom_vocab: &[&str]) -> Result { + let sequences = custom_vocab + .iter() + .map(|s| ctx.tokenize(s, 99)) + .collect::, _>>()?; + + let mut builder = TrieBuilder::new(); + + for sequence in sequences { + for i in 1..=sequence.len() { + let progress = i as f32 / sequence.len() as f32; + let prefix_bias = 1.0 + 2.0 * progress.powi(2); + builder.push(&sequence[..i], prefix_bias); + } + } + + Ok(BiasTrie { + trie: builder.build(), + }) + } + + pub unsafe fn apply_bias_to_logits( + &self, + tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, + n_tokens: std::os::raw::c_int, + logits: *mut f32, + ) { + if tokens.is_null() || n_tokens <= 0 { + return; + } + + let current_tokens: Vec = + std::slice::from_raw_parts(tokens, n_tokens as usize) + .iter() + .map(|t| t.id) + .collect(); + + for start_pos in (n_tokens as usize).saturating_sub(10)..n_tokens as usize { + let suffix = ¤t_tokens[start_pos..]; + + if self.trie.exact_match(suffix).is_some() { + continue; + } + + for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) { + let bias_value = *bias_value_ref; + let full_sequence: Vec = full_sequence; + + if full_sequence.len() > suffix.len() { + let next_token = full_sequence[suffix.len()]; + let current_logit = *logits.offset(next_token as isize); + *logits.offset(next_token as isize) = current_logit + bias_value.ln(); + } + } + } + } +} diff --git a/crates/whisper-local/src/lib.rs b/crates/whisper-local/src/lib.rs index 057d1e756..d29092a58 100644 --- a/crates/whisper-local/src/lib.rs +++ b/crates/whisper-local/src/lib.rs @@ -9,6 +9,9 @@ pub use model::*; mod error; pub use error::*; +mod bias; +use bias::*; + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] pub struct GgmlBackend { pub kind: String, diff --git a/crates/whisper-local/src/model.rs 
b/crates/whisper-local/src/model.rs index 427962fbd..7b44cc3b2 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -3,12 +3,12 @@ use lazy_static::lazy_static; use regex::Regex; -use trie_rs::map::{Trie, TrieBuilder}; use whisper_rs::{ FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState, WhisperTokenId, }; +use crate::BiasTrie; use hypr_whisper::Language; lazy_static! { @@ -19,6 +19,7 @@ lazy_static! { pub struct WhisperBuilder { model_path: Option, languages: Option>, + vocab: Option>, } impl WhisperBuilder { @@ -32,6 +33,11 @@ impl WhisperBuilder { self } + pub fn vocab(mut self, vocab: Vec) -> Self { + self.vocab = Some(vocab); + self + } + pub fn build(self) -> Result { unsafe { Self::suppress_log() }; @@ -54,22 +60,10 @@ impl WhisperBuilder { let token_beg = ctx.token_beg(); let bias_trie = { - let custom_vacab = vec!["Hyprnote", "OWhisper"]; - let sequences = custom_vacab - .iter() - .map(|s| ctx.tokenize(s, 99)) - .collect::, _>>()?; - - let mut builder = TrieBuilder::new(); - - for sequence in sequences { - for i in 1..=sequence.len() { - let progress = i as f32 / sequence.len() as f32; - let prefix_bias = 1.0 + progress.powi(2); - builder.push(&sequence[..i], prefix_bias); - } - } - builder.build() + let custom_vocab = self.vocab.unwrap_or(vec!["Hyprnote".to_string()]); + + let custom_vocab_refs: Vec<&str> = custom_vocab.iter().map(|s| s.as_str()).collect(); + BiasTrie::new(&ctx, &custom_vocab_refs)? 
}; Ok(Whisper { @@ -97,7 +91,7 @@ pub struct Whisper { dynamic_prompt: String, state: WhisperState, token_beg: WhisperTokenId, - bias_trie: Trie, + bias_trie: BiasTrie, } impl Whisper { @@ -261,11 +255,11 @@ impl Whisper { unsafe fn set_logit_filter( params: &mut FullParams, token_beg: &WhisperTokenId, - bias_trie: &Trie, + bias_trie: &BiasTrie, ) { struct Context { token_beg: WhisperTokenId, - bias_trie: Trie, + bias_trie: BiasTrie, } let context = Box::new(Context { @@ -287,41 +281,11 @@ impl Whisper { let context = &*(user_data as *const Context); - { - *logits.offset(context.token_beg as isize) = f32::NEG_INFINITY; - } + *logits.offset(context.token_beg as isize) = f32::NEG_INFINITY; - { - if !tokens.is_null() && n_tokens > 0 { - let current_tokens: Vec = - std::slice::from_raw_parts(tokens, n_tokens as usize) - .iter() - .map(|t| t.id) - .collect(); - - for start_pos in (n_tokens as usize).saturating_sub(10)..n_tokens as usize { - let suffix = ¤t_tokens[start_pos..]; - - if let Some(_) = context.bias_trie.exact_match(suffix) { - continue; - } - - for (full_sequence, bias_value_ref) in - context.bias_trie.predictive_search(suffix) - { - let bias_value = *bias_value_ref; - let full_sequence: Vec = full_sequence; - - if full_sequence.len() > suffix.len() { - let next_token = full_sequence[suffix.len()]; - let current_logit = *logits.offset(next_token as isize); - *logits.offset(next_token as isize) = - current_logit + bias_value.ln(); - } - } - } - } - } + context + .bias_trie + .apply_bias_to_logits(tokens, n_tokens, logits); } params.set_filter_logits_callback(Some(logits_filter_callback)); From 607a66dd95ec5e9d20997f5b5fe3214f81c89c1e Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 28 Aug 2025 16:20:48 -0700 Subject: [PATCH 3/5] done --- Cargo.lock | 1 + .../src/service/streaming.rs | 17 ++-- crates/whisper-local/Cargo.toml | 5 ++ .../benches/whisper_transcription.rs | 88 +++++++++++++++++++ crates/whisper-local/src/bias.rs | 28 +++--- 
crates/whisper-local/src/model.rs | 27 ++++-- owhisper/owhisper-interface/src/lib.rs | 3 + plugins/listener/src/fsm.rs | 21 ++++- 8 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 crates/whisper-local/benches/whisper_transcription.rs diff --git a/Cargo.lock b/Cargo.lock index 2eeacfad8..c53a6e529 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17522,6 +17522,7 @@ name = "whisper-local" version = "0.1.0" dependencies = [ "audio-utils", + "criterion", "dasp", "data", "dirs 6.0.0", diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index 373d3ab3b..d5574a27e 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -90,15 +90,18 @@ where } }; + let languages = params + .languages + .iter() + .filter_map(|lang| lang.clone().try_into().ok()) + .collect::>(); + + let vocabulary = params.vocabulary.clone(); + let model = match hypr_whisper_local::Whisper::builder() .model_path(model_path.to_str().unwrap()) - .languages( - params - .languages - .iter() - .filter_map(|lang| lang.clone().try_into().ok()) - .collect::>(), - ) + .languages(languages) + .vocabulary(vocabulary) .build() { Ok(model) => model, diff --git a/crates/whisper-local/Cargo.toml b/crates/whisper-local/Cargo.toml index 68379f85e..2134cc8ed 100644 --- a/crates/whisper-local/Cargo.toml +++ b/crates/whisper-local/Cargo.toml @@ -17,10 +17,15 @@ openmp = ["whisper-rs/openmp"] [dev-dependencies] hypr-data = { workspace = true } +criterion = { workspace = true } dirs = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true } +[[bench]] +name = "whisper_transcription" +harness = false + [dependencies] hypr-audio-utils = { workspace = true } hypr-whisper = { workspace = true } diff --git a/crates/whisper-local/benches/whisper_transcription.rs b/crates/whisper-local/benches/whisper_transcription.rs new file mode 
100644 index 000000000..cf65a4c1f --- /dev/null +++ b/crates/whisper-local/benches/whisper_transcription.rs @@ -0,0 +1,88 @@ +use std::hint::black_box; +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, Criterion}; +use hypr_whisper::Language; +use whisper_local::Whisper; + +fn benchmark_whisper_transcription(c: &mut Criterion) { + let audio: Vec = hypr_data::english_1::AUDIO + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0) + .collect(); + + let model_path = concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin"); + + let mut whisper_without_vocab = Whisper::builder() + .model_path(model_path) + .languages(vec![Language::En]) + .build() + .unwrap(); + + let mut whisper_with_vocab = Whisper::builder() + .model_path(model_path) + .languages(vec![Language::En]) + .vocabulary( + vec![ + "profound", + "acquire", + "complementary", + "deeply", + "repositories", + "brilliant", + "pockets", + "thread", + "stumbling", + "stumble", + "communities", + "invested", + "undergrad", + "Googleable", + "exploring", + "neuroscientist", + "psychology", + "engineering", + "researcher", + "thinker", + "skill", + "invest", + "solved", + "entire", + "especially", + "actually", + "often", + "already", + "important", + "definitely", + "much", + ] + .into_iter() + .map(|s| s.into()) + .collect(), + ) + .build() + .unwrap(); + + let mut group = c.benchmark_group("whisper_comparison"); + group.measurement_time(Duration::from_secs(100)); + group.sample_size(10); + + group.bench_function("without_vocab", |b| { + b.iter(|| { + let segments = whisper_without_vocab.transcribe(black_box(&audio)).unwrap(); + black_box(segments) + }) + }); + + group.bench_function("with_vocab", |b| { + b.iter(|| { + let segments = whisper_with_vocab.transcribe(black_box(&audio)).unwrap(); + black_box(segments) + }) + }); + + group.finish(); +} + +criterion_group!(benches, benchmark_whisper_transcription); +criterion_main!(benches); diff --git 
a/crates/whisper-local/src/bias.rs b/crates/whisper-local/src/bias.rs index 4a4184987..2270b3503 100644 --- a/crates/whisper-local/src/bias.rs +++ b/crates/whisper-local/src/bias.rs @@ -18,14 +18,14 @@ impl BiasTrie { for sequence in sequences { for i in 1..=sequence.len() { let progress = i as f32 / sequence.len() as f32; - let prefix_bias = 1.0 + 2.0 * progress.powi(2); - builder.push(&sequence[..i], prefix_bias); + let prefix_bias = 1.0 + 10.0 * progress.powi(2); + let prefix = &sequence[..i]; + builder.push(prefix, prefix_bias); } } + let trie = builder.build(); - Ok(BiasTrie { - trie: builder.build(), - }) + Ok(BiasTrie { trie }) } pub unsafe fn apply_bias_to_logits( @@ -44,23 +44,29 @@ impl BiasTrie { .map(|t| t.id) .collect(); - for start_pos in (n_tokens as usize).saturating_sub(10)..n_tokens as usize { - let suffix = ¤t_tokens[start_pos..]; + for suffix_len in 1..=std::cmp::min(5, current_tokens.len()) { + let suffix = ¤t_tokens[current_tokens.len() - suffix_len..]; - if self.trie.exact_match(suffix).is_some() { - continue; - } + let mut found_continuations = false; for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) { + found_continuations = true; + let bias_value = *bias_value_ref; let full_sequence: Vec = full_sequence; if full_sequence.len() > suffix.len() { let next_token = full_sequence[suffix.len()]; let current_logit = *logits.offset(next_token as isize); - *logits.offset(next_token as isize) = current_logit + bias_value.ln(); + let new_logit = current_logit + bias_value.ln(); + + *logits.offset(next_token as isize) = new_logit; } } + + if found_continuations { + break; + } } } } diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index 7b44cc3b2..c91c0bcb9 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -19,7 +19,7 @@ lazy_static! 
{ pub struct WhisperBuilder { model_path: Option, languages: Option>, - vocab: Option>, + vocabulary: Option>, } impl WhisperBuilder { @@ -33,8 +33,8 @@ impl WhisperBuilder { self } - pub fn vocab(mut self, vocab: Vec) -> Self { - self.vocab = Some(vocab); + pub fn vocabulary(mut self, vocabulary: Vec) -> Self { + self.vocabulary = Some(vocabulary); self } @@ -60,7 +60,7 @@ impl WhisperBuilder { let token_beg = ctx.token_beg(); let bias_trie = { - let custom_vocab = self.vocab.unwrap_or(vec!["Hyprnote".to_string()]); + let custom_vocab = self.vocabulary.unwrap_or(vec!["Hyprnote".to_string()]); let custom_vocab_refs: Vec<&str> = custom_vocab.iter().map(|s| s.as_str()).collect(); BiasTrie::new(&ctx, &custom_vocab_refs)? @@ -343,6 +343,15 @@ mod tests { fn test_whisper() { let mut whisper = Whisper::builder() .model_path(concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin")) + .vocabulary( + vec![ + "Google", "should", "people", "question", "learning", "research", "problem", + "like", "actually", + ] + .into_iter() + .map(|s| s.into()) + .collect(), + ) .build() .unwrap(); @@ -352,7 +361,15 @@ mod tests { .collect(); let segments = whisper.transcribe(&audio).unwrap(); - println!("segments: {:#?}", segments); assert!(segments.len() > 0); + + println!( + "{}", + segments + .iter() + .map(|s| s.text.clone()) + .collect::>() + .join(" ") + ); } } diff --git a/owhisper/owhisper-interface/src/lib.rs b/owhisper/owhisper-interface/src/lib.rs index c7f0c2c29..213956aa5 100644 --- a/owhisper/owhisper-interface/src/lib.rs +++ b/owhisper/owhisper-interface/src/lib.rs @@ -125,6 +125,8 @@ common_derives! 
{ // https://docs.rs/axum-extra/0.10.1/axum_extra/extract/struct.Query.html#example-1 #[serde(default)] pub languages: Vec, + #[serde(default)] + pub vocabulary: Vec, pub redemption_time_ms: Option, } } @@ -135,6 +137,7 @@ impl Default for ListenParams { model: None, channels: 1, languages: vec![], + vocabulary: vec![], redemption_time_ms: None, } } diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index e6d09f6c8..880ba5624 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -209,7 +209,7 @@ impl Session { let user_id = self.app.db_user_id().await?.unwrap(); self.session_id = Some(session_id.clone()); - let (record, languages) = { + let (record, languages, vocabulary) = { let config = self.app.db_get_config(&user_id).await?; let record = config @@ -221,7 +221,12 @@ impl Session { |c| c.general.spoken_languages.clone(), ); - (record, languages) + let vocabulary = config.as_ref().map_or_else( + || vec!["Hyprnote".to_string()], + |c| c.general.jargons.clone(), + ); + + (record, languages, vocabulary) }; let session = self @@ -243,8 +248,14 @@ impl Session { self.speaker_muted_rx = Some(speaker_muted_rx_main.clone()); self.session_state_tx = Some(session_state_tx); - let listen_client = - setup_listen_client(&self.app, languages, session_id == onboarding_session_id).await?; + let listen_client = setup_listen_client( + &self.app, + vocabulary, + languages, + session_id == onboarding_session_id, + ) + .await?; + let mic_sample_stream = { let mut input = hypr_audio::AudioInput::from_mic(self.mic_device_name.clone())?; input.stream() @@ -603,6 +614,7 @@ impl Session { async fn setup_listen_client( app: &tauri::AppHandle, + vocabulary: Vec, languages: Vec, is_onboarding: bool, ) -> Result { @@ -615,6 +627,7 @@ async fn setup_listen_client( .api_base(conn.base_url) .api_key(conn.api_key.unwrap_or_default()) .params(owhisper_interface::ListenParams { + vocabulary, languages, redemption_time_ms: Some(if is_onboarding { 60 } else 
{ 400 }), ..Default::default() From ff2bc1a51bba39079474c97eb39ccf42e701b714 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 28 Aug 2025 17:35:22 -0700 Subject: [PATCH 4/5] fix memory leak --- crates/whisper-local/src/model.rs | 41 +++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index c91c0bcb9..f478086c8 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -109,7 +109,7 @@ impl Whisper { let token_beg = self.token_beg; let language = self.get_language(audio)?; - let params = { + let mut params = { let mut p = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); let parts = [self.dynamic_prompt.trim()]; @@ -124,10 +124,6 @@ impl Whisper { p.set_initial_prompt(&initial_prompt); - unsafe { - Self::set_logit_filter(&mut p, &token_beg, &self.bias_trie); - } - p.set_no_timestamps(true); p.set_token_timestamps(false); p.set_split_on_word(true); @@ -146,6 +142,8 @@ impl Whisper { p }; + let _guard = unsafe { Self::set_logit_filter(&mut params, &token_beg, &self.bias_trie) }; + self.state.full(params, &audio[..])?; let num_segments = self.state.full_n_segments(); @@ -256,12 +254,7 @@ impl Whisper { params: &mut FullParams, token_beg: &WhisperTokenId, bias_trie: &BiasTrie, - ) { - struct Context { - token_beg: WhisperTokenId, - bias_trie: BiasTrie, - } - + ) -> LogitFilterGuard { let context = Box::new(Context { token_beg: *token_beg, bias_trie: bias_trie.clone(), @@ -288,9 +281,31 @@ impl Whisper { .apply_bias_to_logits(tokens, n_tokens, logits); } + let context_ptr = Box::into_raw(context) as *mut std::ffi::c_void; + params.set_filter_logits_callback(Some(logits_filter_callback)); - params - .set_filter_logits_callback_user_data(Box::into_raw(context) as *mut std::ffi::c_void); + params.set_filter_logits_callback_user_data(context_ptr); + + LogitFilterGuard { context_ptr } + } +} + +struct Context { + 
token_beg: WhisperTokenId, + bias_trie: BiasTrie, +} + +struct LogitFilterGuard { + context_ptr: *mut std::ffi::c_void, +} + +impl Drop for LogitFilterGuard { + fn drop(&mut self) { + if !self.context_ptr.is_null() { + unsafe { + let _ = Box::from_raw(self.context_ptr as *mut Context); + } + } } } From 8adcf400fc65b4b4a0b8f1aa994ead12eac70614 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Fri, 29 Aug 2025 09:31:12 -0700 Subject: [PATCH 5/5] wip --- crates/whisper-local/src/bias.rs | 65 +++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/crates/whisper-local/src/bias.rs b/crates/whisper-local/src/bias.rs index 2270b3503..d884b59b9 100644 --- a/crates/whisper-local/src/bias.rs +++ b/crates/whisper-local/src/bias.rs @@ -8,26 +8,53 @@ pub struct BiasTrie { impl BiasTrie { pub fn new(ctx: &WhisperContext, custom_vocab: &[&str]) -> Result { - let sequences = custom_vocab - .iter() - .map(|s| ctx.tokenize(s, 99)) - .collect::, _>>()?; - let mut builder = TrieBuilder::new(); - for sequence in sequences { - for i in 1..=sequence.len() { - let progress = i as f32 / sequence.len() as f32; - let prefix_bias = 1.0 + 10.0 * progress.powi(2); - let prefix = &sequence[..i]; - builder.push(prefix, prefix_bias); + for word in custom_vocab { + let variants = Self::generate_tokenization_variants(ctx, word)?; + + for tokens in variants { + for i in 1..=tokens.len() { + let progress = i as f32 / tokens.len() as f32; + + let prefix_bias = 10.0 + 90.0 * progress.powi(2); + + let prefix = &tokens[..i]; + builder.push(prefix, prefix_bias); + } } } - let trie = builder.build(); + let trie = builder.build(); Ok(BiasTrie { trie }) } + fn generate_tokenization_variants( + ctx: &WhisperContext, + word: &str, + ) -> Result>, crate::Error> { + let mut variants = Vec::new(); + + variants.push(ctx.tokenize(word, 99)?); + variants.push(ctx.tokenize(&format!(" {}", word), 99)?); + + let lower = word.to_lowercase(); + if lower != word { + 
variants.push(ctx.tokenize(&lower, 99)?); + variants.push(ctx.tokenize(&format!(" {}", lower), 99)?); + } + + let upper = word.to_uppercase(); + if upper != word { + variants.push(ctx.tokenize(&upper, 99)?); + } + + variants.push(ctx.tokenize(&format!("'{}", word), 99)?); + variants.push(ctx.tokenize(&format!("\"{}", word), 99)?); + + Ok(variants) + } + pub unsafe fn apply_bias_to_logits( &self, tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, @@ -44,29 +71,23 @@ impl BiasTrie { .map(|t| t.id) .collect(); - for suffix_len in 1..=std::cmp::min(5, current_tokens.len()) { + for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) { let suffix = &current_tokens[current_tokens.len() - suffix_len..]; - let mut found_continuations = false; - for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) { - found_continuations = true; - let bias_value = *bias_value_ref; let full_sequence: Vec<WhisperTokenId> = full_sequence; if full_sequence.len() > suffix.len() { let next_token = full_sequence[suffix.len()]; let current_logit = *logits.offset(next_token as isize); - let new_logit = current_logit + bias_value.ln(); + + let boost = bias_value.ln() * 2.0; + let new_logit = current_logit + boost; *logits.offset(next_token as isize) = new_logit; } } - - if found_continuations { - break; - } } } }