diff --git a/Cargo.lock b/Cargo.lock index 6e71a66b9..c53a6e529 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4541,6 +4541,15 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "fid-rs" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6956a1e60e2d1412b44b4169d44a03dae518f8583d3e10090c912c105e48447" +dependencies = [ + "rayon", +] + [[package]] name = "field-offset" version = "0.3.6" @@ -8607,6 +8616,15 @@ dependencies = [ "url", ] +[[package]] +name = "louds-rs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "936de6c22f08e7135a921f8ada907acd0d88880c4f42b5591f634b9f1dd8e07f" +dependencies = [ + "fid-rs", +] + [[package]] name = "lru" version = "0.12.5" @@ -16455,6 +16473,15 @@ dependencies = [ "petgraph 0.6.5", ] +[[package]] +name = "trie-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f88f4b0a1ebd6c3d16be3e45eb0e8089372ccadd88849b7ca162ba64b5e6f6" +dependencies = [ + "louds-rs", +] + [[package]] name = "trim-in-place" version = "0.1.7" @@ -17495,6 +17522,7 @@ name = "whisper-local" version = "0.1.0" dependencies = [ "audio-utils", + "criterion", "dasp", "data", "dirs 6.0.0", @@ -17509,6 +17537,7 @@ dependencies = [ "thiserror 2.0.16", "tokio", "tracing", + "trie-rs", "whisper", "whisper-rs", ] @@ -17525,8 +17554,7 @@ dependencies = [ [[package]] name = "whisper-rs" version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4a5eb3a2a84d3adfb2e84b3ae783ee73428676582b2c91cb66c3091e737256" +source = "git+https://codeberg.org/tazz4843/whisper-rs?rev=3e6d3da#3e6d3da162146e8697fc39c4a81e66a339f14bc6" dependencies = [ "libc", "tracing", @@ -17536,8 +17564,7 @@ dependencies = [ [[package]] name = "whisper-rs-sys" version = "0.14.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0927ebac46387e6f3f705cc007c9ea40885f6d924071a918f1ee8b829cf5cd" +source = "git+https://codeberg.org/tazz4843/whisper-rs?rev=3e6d3da#3e6d3da162146e8697fc39c4a81e66a339f14bc6" dependencies = [ "bindgen 0.71.1", "cfg-if", diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index 373d3ab3b..d5574a27e 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -90,15 +90,18 @@ where } }; + let languages = params + .languages + .iter() + .filter_map(|lang| lang.clone().try_into().ok()) + .collect::>(); + + let vocabulary = params.vocabulary.clone(); + let model = match hypr_whisper_local::Whisper::builder() .model_path(model_path.to_str().unwrap()) - .languages( - params - .languages - .iter() - .filter_map(|lang| lang.clone().try_into().ok()) - .collect::>(), - ) + .languages(languages) + .vocabulary(vocabulary) .build() { Ok(model) => model, diff --git a/crates/whisper-local/Cargo.toml b/crates/whisper-local/Cargo.toml index 90c18d982..2134cc8ed 100644 --- a/crates/whisper-local/Cargo.toml +++ b/crates/whisper-local/Cargo.toml @@ -17,10 +17,15 @@ openmp = ["whisper-rs/openmp"] [dev-dependencies] hypr-data = { workspace = true } +criterion = { workspace = true } dirs = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true } +[[bench]] +name = "whisper_transcription" +harness = false + [dependencies] hypr-audio-utils = { workspace = true } hypr-whisper = { workspace = true } @@ -28,7 +33,7 @@ hypr-whisper = { workspace = true } dasp = { workspace = true } kalosm-sound = { workspace = true, default-features = false } rodio = { workspace = true } -whisper-rs = { version = "0.15.0", features = ["raw-api", "tracing_backend"] } +whisper-rs = { git = "https://codeberg.org/tazz4843/whisper-rs", rev = "3e6d3da", features = 
["raw-api", "tracing_backend"] } futures-util = { workspace = true } tracing = { workspace = true } @@ -37,6 +42,7 @@ serde = { workspace = true } serde_json = { workspace = true } specta = { workspace = true, features = ["derive"] } thiserror = { workspace = true } +trie-rs = "0.4.2" lazy_static = { workspace = true } regex = { workspace = true } diff --git a/crates/whisper-local/benches/whisper_transcription.rs b/crates/whisper-local/benches/whisper_transcription.rs new file mode 100644 index 000000000..cf65a4c1f --- /dev/null +++ b/crates/whisper-local/benches/whisper_transcription.rs @@ -0,0 +1,88 @@ +use std::hint::black_box; +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, Criterion}; +use hypr_whisper::Language; +use whisper_local::Whisper; + +fn benchmark_whisper_transcription(c: &mut Criterion) { + let audio: Vec = hypr_data::english_1::AUDIO + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0) + .collect(); + + let model_path = concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin"); + + let mut whisper_without_vocab = Whisper::builder() + .model_path(model_path) + .languages(vec![Language::En]) + .build() + .unwrap(); + + let mut whisper_with_vocab = Whisper::builder() + .model_path(model_path) + .languages(vec![Language::En]) + .vocabulary( + vec![ + "profound", + "acquire", + "complementary", + "deeply", + "repositories", + "brilliant", + "pockets", + "thread", + "stumbling", + "stumble", + "communities", + "invested", + "undergrad", + "Googleable", + "exploring", + "neuroscientist", + "psychology", + "engineering", + "researcher", + "thinker", + "skill", + "invest", + "solved", + "entire", + "especially", + "actually", + "often", + "already", + "important", + "definitely", + "much", + ] + .into_iter() + .map(|s| s.into()) + .collect(), + ) + .build() + .unwrap(); + + let mut group = c.benchmark_group("whisper_comparison"); + group.measurement_time(Duration::from_secs(100)); + 
group.sample_size(10); + + group.bench_function("without_vocab", |b| { + b.iter(|| { + let segments = whisper_without_vocab.transcribe(black_box(&audio)).unwrap(); + black_box(segments) + }) + }); + + group.bench_function("with_vocab", |b| { + b.iter(|| { + let segments = whisper_with_vocab.transcribe(black_box(&audio)).unwrap(); + black_box(segments) + }) + }); + + group.finish(); +} + +criterion_group!(benches, benchmark_whisper_transcription); +criterion_main!(benches); diff --git a/crates/whisper-local/src/bias.rs b/crates/whisper-local/src/bias.rs new file mode 100644 index 000000000..d884b59b9 --- /dev/null +++ b/crates/whisper-local/src/bias.rs @@ -0,0 +1,93 @@ +use trie_rs::map::{Trie, TrieBuilder}; +use whisper_rs::{WhisperContext, WhisperTokenId}; + +#[derive(Clone)] +pub struct BiasTrie { + trie: Trie, +} + +impl BiasTrie { + pub fn new(ctx: &WhisperContext, custom_vocab: &[&str]) -> Result { + let mut builder = TrieBuilder::new(); + + for word in custom_vocab { + let variants = Self::generate_tokenization_variants(ctx, word)?; + + for tokens in variants { + for i in 1..=tokens.len() { + let progress = i as f32 / tokens.len() as f32; + + let prefix_bias = 10.0 + 90.0 * progress.powi(2); + + let prefix = &tokens[..i]; + builder.push(prefix, prefix_bias); + } + } + } + + let trie = builder.build(); + Ok(BiasTrie { trie }) + } + + fn generate_tokenization_variants( + ctx: &WhisperContext, + word: &str, + ) -> Result>, crate::Error> { + let mut variants = Vec::new(); + + variants.push(ctx.tokenize(word, 99)?); + variants.push(ctx.tokenize(&format!(" {}", word), 99)?); + + let lower = word.to_lowercase(); + if lower != word { + variants.push(ctx.tokenize(&lower, 99)?); + variants.push(ctx.tokenize(&format!(" {}", lower), 99)?); + } + + let upper = word.to_uppercase(); + if upper != word { + variants.push(ctx.tokenize(&upper, 99)?); + } + + variants.push(ctx.tokenize(&format!("'{}", word), 99)?); + variants.push(ctx.tokenize(&format!("\"{}", word), 99)?); + + 
Ok(variants) + } + + pub unsafe fn apply_bias_to_logits( + &self, + tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, + n_tokens: std::os::raw::c_int, + logits: *mut f32, + ) { + if tokens.is_null() || n_tokens <= 0 { + return; + } + + let current_tokens: Vec<WhisperTokenId> = + std::slice::from_raw_parts(tokens, n_tokens as usize) + .iter() + .map(|t| t.id) + .collect(); + + for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) { + let suffix = &current_tokens[current_tokens.len() - suffix_len..]; + + for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) { + let bias_value = *bias_value_ref; + let full_sequence: Vec<WhisperTokenId> = full_sequence; + + if full_sequence.len() > suffix.len() { + let next_token = full_sequence[suffix.len()]; + let current_logit = *logits.offset(next_token as isize); + + let boost = bias_value.ln() * 2.0; + let new_logit = current_logit + boost; + + *logits.offset(next_token as isize) = new_logit; + } + } + } + } +} diff --git a/crates/whisper-local/src/lib.rs b/crates/whisper-local/src/lib.rs index 057d1e756..d29092a58 100644 --- a/crates/whisper-local/src/lib.rs +++ b/crates/whisper-local/src/lib.rs @@ -9,6 +9,9 @@ pub use model::*; mod error; pub use error::*; +mod bias; +use bias::*; + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] pub struct GgmlBackend { pub kind: String, diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index 4f8a0c5b6..f478086c8 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -8,6 +8,7 @@ use whisper_rs::{ WhisperTokenId, }; +use crate::BiasTrie; use hypr_whisper::Language; lazy_static! { @@ -18,6 +19,7 @@ lazy_static! 
{ pub struct WhisperBuilder { model_path: Option, languages: Option>, + vocabulary: Option>, } impl WhisperBuilder { @@ -31,6 +33,11 @@ impl WhisperBuilder { self } + pub fn vocabulary(mut self, vocabulary: Vec) -> Self { + self.vocabulary = Some(vocabulary); + self + } + pub fn build(self) -> Result { unsafe { Self::suppress_log() }; @@ -52,11 +59,19 @@ impl WhisperBuilder { let state = ctx.create_state()?; let token_beg = ctx.token_beg(); + let bias_trie = { + let custom_vocab = self.vocabulary.unwrap_or(vec!["Hyprnote".to_string()]); + + let custom_vocab_refs: Vec<&str> = custom_vocab.iter().map(|s| s.as_str()).collect(); + BiasTrie::new(&ctx, &custom_vocab_refs)? + }; + Ok(Whisper { languages: self.languages.unwrap_or_default(), dynamic_prompt: "".to_string(), state, token_beg, + bias_trie, }) } @@ -76,6 +91,7 @@ pub struct Whisper { dynamic_prompt: String, state: WhisperState, token_beg: WhisperTokenId, + bias_trie: BiasTrie, } impl Whisper { @@ -93,7 +109,7 @@ impl Whisper { let token_beg = self.token_beg; let language = self.get_language(audio)?; - let params = { + let mut params = { let mut p = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); let parts = [self.dynamic_prompt.trim()]; @@ -108,10 +124,6 @@ impl Whisper { p.set_initial_prompt(&initial_prompt); - unsafe { - Self::suppress_beg(&mut p, &token_beg); - } - p.set_no_timestamps(true); p.set_token_timestamps(false); p.set_split_on_word(true); @@ -130,6 +142,8 @@ impl Whisper { p }; + let _guard = unsafe { Self::set_logit_filter(&mut params, &token_beg, &self.bias_trie) }; + self.state.full(params, &audio[..])?; let num_segments = self.state.full_n_segments(); @@ -236,12 +250,21 @@ impl Whisper { .collect() } - unsafe fn suppress_beg(params: &mut FullParams, token_beg: &WhisperTokenId) { + unsafe fn set_logit_filter( + params: &mut FullParams, + token_beg: &WhisperTokenId, + bias_trie: &BiasTrie, + ) -> LogitFilterGuard { + let context = Box::new(Context { + token_beg: *token_beg, + bias_trie: 
bias_trie.clone(), + }); + unsafe extern "C" fn logits_filter_callback( _ctx: *mut whisper_rs::whisper_rs_sys::whisper_context, _state: *mut whisper_rs::whisper_rs_sys::whisper_state, - _tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, - _n_tokens: std::os::raw::c_int, + tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, + n_tokens: std::os::raw::c_int, logits: *mut f32, user_data: *mut std::os::raw::c_void, ) { @@ -249,14 +272,40 @@ impl Whisper { return; } - let token_beg_id = *(user_data as *const WhisperTokenId); - *logits.offset(token_beg_id as isize) = f32::NEG_INFINITY; + let context = &*(user_data as *const Context); + + *logits.offset(context.token_beg as isize) = f32::NEG_INFINITY; + + context + .bias_trie + .apply_bias_to_logits(tokens, n_tokens, logits); } + let context_ptr = Box::into_raw(context) as *mut std::ffi::c_void; + params.set_filter_logits_callback(Some(logits_filter_callback)); - params.set_filter_logits_callback_user_data( - token_beg as *const WhisperTokenId as *mut std::ffi::c_void, - ); + params.set_filter_logits_callback_user_data(context_ptr); + + LogitFilterGuard { context_ptr } + } +} + +struct Context { + token_beg: WhisperTokenId, + bias_trie: BiasTrie, +} + +struct LogitFilterGuard { + context_ptr: *mut std::ffi::c_void, +} + +impl Drop for LogitFilterGuard { + fn drop(&mut self) { + if !self.context_ptr.is_null() { + unsafe { + let _ = Box::from_raw(self.context_ptr as *mut Context); + } + } } } @@ -309,6 +358,15 @@ mod tests { fn test_whisper() { let mut whisper = Whisper::builder() .model_path(concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin")) + .vocabulary( + vec![ + "Google", "should", "people", "question", "learning", "research", "problem", + "like", "actually", + ] + .into_iter() + .map(|s| s.into()) + .collect(), + ) .build() .unwrap(); @@ -318,7 +376,15 @@ mod tests { .collect(); let segments = whisper.transcribe(&audio).unwrap(); - println!("segments: {:#?}", segments); assert!(segments.len() > 
0); + + println!( + "{}", + segments + .iter() + .map(|s| s.text.clone()) + .collect::>() + .join(" ") + ); } } diff --git a/owhisper/owhisper-interface/src/lib.rs b/owhisper/owhisper-interface/src/lib.rs index c7f0c2c29..213956aa5 100644 --- a/owhisper/owhisper-interface/src/lib.rs +++ b/owhisper/owhisper-interface/src/lib.rs @@ -125,6 +125,8 @@ common_derives! { // https://docs.rs/axum-extra/0.10.1/axum_extra/extract/struct.Query.html#example-1 #[serde(default)] pub languages: Vec, + #[serde(default)] + pub vocabulary: Vec, pub redemption_time_ms: Option, } } @@ -135,6 +137,7 @@ impl Default for ListenParams { model: None, channels: 1, languages: vec![], + vocabulary: vec![], redemption_time_ms: None, } } diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index e6d09f6c8..880ba5624 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -209,7 +209,7 @@ impl Session { let user_id = self.app.db_user_id().await?.unwrap(); self.session_id = Some(session_id.clone()); - let (record, languages) = { + let (record, languages, vocabulary) = { let config = self.app.db_get_config(&user_id).await?; let record = config @@ -221,7 +221,12 @@ impl Session { |c| c.general.spoken_languages.clone(), ); - (record, languages) + let vocabulary = config.as_ref().map_or_else( + || vec!["Hyprnote".to_string()], + |c| c.general.jargons.clone(), + ); + + (record, languages, vocabulary) }; let session = self @@ -243,8 +248,14 @@ impl Session { self.speaker_muted_rx = Some(speaker_muted_rx_main.clone()); self.session_state_tx = Some(session_state_tx); - let listen_client = - setup_listen_client(&self.app, languages, session_id == onboarding_session_id).await?; + let listen_client = setup_listen_client( + &self.app, + vocabulary, + languages, + session_id == onboarding_session_id, + ) + .await?; + let mic_sample_stream = { let mut input = hypr_audio::AudioInput::from_mic(self.mic_device_name.clone())?; input.stream() @@ -603,6 +614,7 @@ impl Session { 
async fn setup_listen_client( app: &tauri::AppHandle, + vocabulary: Vec, languages: Vec, is_onboarding: bool, ) -> Result { @@ -615,6 +627,7 @@ async fn setup_listen_client( .api_base(conn.base_url) .api_key(conn.api_key.unwrap_or_default()) .params(owhisper_interface::ListenParams { + vocabulary, languages, redemption_time_ms: Some(if is_onboarding { 60 } else { 400 }), ..Default::default()