diff --git a/Cargo.lock b/Cargo.lock
index 285589ac8f..164bd63fb9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -11153,6 +11153,7 @@ dependencies = [
  "tokio-stream",
  "tracing",
  "tracing-subscriber",
+ "ureq 2.12.1",
  "url",
  "ws",
 ]
diff --git a/owhisper/owhisper-client/Cargo.toml b/owhisper/owhisper-client/Cargo.toml
index fb9cd8b856..298f09e365 100644
--- a/owhisper/owhisper-client/Cargo.toml
+++ b/owhisper/owhisper-client/Cargo.toml
@@ -15,6 +15,7 @@ futures-util = { workspace = true }
 reqwest = { workspace = true, features = ["json", "multipart"] }
 tokio = { workspace = true }
 tokio-stream = { workspace = true }
+ureq = { version = "2", features = ["json"] }
 
 bytes = { workspace = true }
 serde = { workspace = true }
diff --git a/owhisper/owhisper-client/src/adapter/gladia/live.rs b/owhisper/owhisper-client/src/adapter/gladia/live.rs
new file mode 100644
index 0000000000..ab01f9a4bf
--- /dev/null
+++ b/owhisper/owhisper-client/src/adapter/gladia/live.rs
@@ -0,0 +1,482 @@
+use std::collections::HashMap;
+use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};
+use std::time::Duration;
+
+use hypr_ws::client::Message;
+use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse};
+use owhisper_interface::ListenParams;
+use serde::{Deserialize, Serialize};
+
+use super::GladiaAdapter;
+use crate::adapter::parsing::WordBuilder;
+use crate::adapter::RealtimeSttAdapter;
+
+/// Upper bound on the blocking session-initialisation request, so that a
+/// stalled endpoint cannot hang the caller indefinitely.
+const INIT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// Process-wide map from Gladia session id to the channel count requested when
+/// the session was created.
+///
+/// Entries are inserted by `build_ws_url_with_api_key` and removed when the
+/// matching `end_session` frame is parsed.
+fn session_channels() -> &'static Mutex<HashMap<String, u8>> {
+    static SESSION_CHANNELS: OnceLock<Mutex<HashMap<String, u8>>> = OnceLock::new();
+    SESSION_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+/// Locks the session map, recovering the guard if a previous holder panicked.
+///
+/// The map only stores plain `String -> u8` entries, so a poisoned lock cannot
+/// expose a half-updated value.
+fn lock_session_channels() -> MutexGuard<'static, HashMap<String, u8>> {
+    session_channels()
+        .lock()
+        .unwrap_or_else(PoisonError::into_inner)
+}
+
+/// Appends `params` to the query string of `url` and returns the updated URL.
+fn with_query_params(mut url: url::Url, params: &[(String, String)]) -> url::Url {
+    // Skip the mutator for an empty list: `query_pairs_mut` alone would leave a
+    // bare `?` on the URL.
+    if !params.is_empty() {
+        let mut query_pairs = url.query_pairs_mut();
+        for (key, value) in params {
+            query_pairs.append_pair(key, value);
+        }
+    }
+    url
+}
+
+impl RealtimeSttAdapter for GladiaAdapter {
+    fn provider_name(&self) -> &'static str {
+        "gladia"
+    }
+
+    fn supports_native_multichannel(&self) -> bool {
+        true
+    }
+
+    /// Builds the static WebSocket URL for `api_base` without creating a session.
+    fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url {
+        let (url, existing_params) = Self::build_ws_url_from_base(api_base);
+        with_query_params(url, &existing_params)
+    }
+
+    /// Creates a Gladia live session and returns the WebSocket URL it hands back.
+    ///
+    /// Proxy bases (as recognised by `build_proxy_ws_url`) skip session creation
+    /// and return the proxy URL directly. Otherwise the audio configuration is
+    /// POSTed to `/v2/live` with the key in the `x-gladia-key` header, and the
+    /// `url` field of the response is parsed and returned.
+    ///
+    /// Returns `None` when no API key is supplied, when `api_base` does not
+    /// parse as a URL, or when the request or its response cannot be handled
+    /// (the latter failures are logged).
+    ///
+    /// NOTE(review): this is a blocking `ureq` call; callers on an async
+    /// executor should confirm that stalling the current thread for up to
+    /// `INIT_REQUEST_TIMEOUT` is acceptable.
+    fn build_ws_url_with_api_key(
+        &self,
+        api_base: &str,
+        params: &ListenParams,
+        channels: u8,
+        api_key: Option<&str>,
+    ) -> Option<url::Url> {
+        if let Some((url, existing_params)) = crate::adapter::build_proxy_ws_url(api_base) {
+            return Some(with_query_params(url, &existing_params));
+        }
+
+        let key = api_key?;
+
+        let post_url = if api_base.is_empty() {
+            format!("https://{}{}", super::DEFAULT_API_HOST, super::WS_PATH)
+        } else {
+            let parsed: url::Url = api_base.parse().ok()?;
+            let host = parsed.host_str().unwrap_or(super::DEFAULT_API_HOST);
+            let scheme = if crate::adapter::is_local_host(host) {
+                "http"
+            } else {
+                "https"
+            };
+            let host_with_port = match parsed.port() {
+                Some(port) => format!("{host}:{port}"),
+                None => host.to_string(),
+            };
+            format!("{scheme}://{host_with_port}{}", super::WS_PATH)
+        };
+
+        let language_config = if params.languages.is_empty() {
+            None
+        } else {
+            Some(LanguageConfig {
+                languages: params
+                    .languages
+                    .iter()
+                    .map(|l| l.iso639().code().to_string())
+                    .collect(),
+            })
+        };
+
+        let custom_vocabulary = if params.keywords.is_empty() {
+            None
+        } else {
+            Some(params.keywords.clone())
+        };
+
+        let body = GladiaConfig {
+            encoding: "wav/pcm",
+            sample_rate: params.sample_rate,
+            bit_depth: 16,
+            channels,
+            language_config,
+            custom_vocabulary,
+            messages_config: Some(MessagesConfig {
+                receive_partial_transcripts: true,
+            }),
+        };
+
+        let body_json = match serde_json::to_value(&body) {
+            Ok(v) => v,
+            Err(e) => {
+                tracing::error!(error = ?e, "gladia_init_serialize_failed");
+                return None;
+            }
+        };
+
+        let resp = match ureq::post(&post_url)
+            .timeout(INIT_REQUEST_TIMEOUT)
+            .set("x-gladia-key", key)
+            .set("Content-Type", "application/json")
+            .send_json(body_json)
+        {
+            Ok(r) => r,
+            Err(e) => {
+                tracing::error!(error = ?e, "gladia_init_request_failed");
+                return None;
+            }
+        };
+
+        let init: InitResponse = match resp.into_json() {
+            Ok(r) => r,
+            Err(e) => {
+                tracing::error!(error = ?e, "gladia_init_parse_failed");
+                return None;
+            }
+        };
+
+        // Validate the URL before registering the session so that a malformed
+        // response cannot leave an entry in the map that nothing will remove.
+        let ws_url = match url::Url::parse(&init.url) {
+            Ok(u) => u,
+            Err(e) => {
+                tracing::error!(error = ?e, session_id = %init.id, "gladia_init_invalid_ws_url");
+                return None;
+            }
+        };
+
+        // No auth header is sent on the socket, so the URL is the only credential
+        // for the session; keep it out of the logs.
+        tracing::debug!(session_id = %init.id, channels = channels, "gladia_session_initialized");
+
+        lock_session_channels().insert(init.id, channels);
+
+        Some(ws_url)
+    }
+
+    fn build_auth_header(&self, _api_key: Option<&str>) -> Option<(&'static str, String)> {
+        None
+    }
+
+    fn keep_alive_message(&self) -> Option<Message> {
+        None
+    }
+
+    fn initial_message(
+        &self,
+        _api_key: Option<&str>,
+        _params: &ListenParams,
+        _channels: u8,
+    ) -> Option<Message> {
+        None
+    }
+
+    fn finalize_message(&self) -> Message {
+        Message::Text(r#"{"type":"stop_recording"}"#.into())
+    }
+
+    /// Decodes one text frame from Gladia into zero or more stream responses.
+    ///
+    /// Only `transcript` and `end_session` frames produce output; every other
+    /// frame type, and any frame that fails to parse, yields an empty vector.
+    fn parse_response(&self, raw: &str) -> Vec<StreamResponse> {
+        let msg: GladiaMessage = match serde_json::from_str(raw) {
+            Ok(m) => m,
+            Err(e) => {
+                tracing::warn!(error = ?e, raw = raw, "gladia_json_parse_failed");
+                return vec![];
+            }
+        };
+
+        match msg {
+            GladiaMessage::Transcript(transcript) => Self::parse_transcript(transcript),
+            GladiaMessage::StartSession { id } => {
+                tracing::debug!(session_id = %id, "gladia_session_started");
+                vec![]
+            }
+            GladiaMessage::EndSession { id } => {
+                // Dropping the entry here releases the session's slot in the map.
+                let channels = lock_session_channels().remove(&id).unwrap_or_else(|| {
+                    tracing::warn!(session_id = %id, "gladia_session_channels_not_found");
+                    1
+                });
+                tracing::debug!(session_id = %id, channels = channels, "gladia_session_ended");
+                vec![StreamResponse::TerminalResponse {
+                    request_id: id,
+                    created: String::new(),
+                    duration: 0.0,
+                    channels: channels.into(),
+                }]
+            }
+            GladiaMessage::SpeechStart { .. }
+            | GladiaMessage::SpeechEnd { .. }
+            | GladiaMessage::StartRecording { .. }
+            | GladiaMessage::EndRecording { .. } => vec![],
+            GladiaMessage::Error { message, code } => {
+                tracing::error!(error = %message, code = ?code, "gladia_error");
+                vec![]
+            }
+            GladiaMessage::Unknown => {
+                tracing::debug!(raw = raw, "gladia_unknown_message");
+                vec![]
+            }
+        }
+    }
+}
+
+/// JSON body POSTed to `/v2/live` to create a session.
+#[derive(Serialize)]
+struct GladiaConfig<'a> {
+    encoding: &'a str,
+    sample_rate: u32,
+    bit_depth: u8,
+    channels: u8,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    language_config: Option<LanguageConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    custom_vocabulary: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    messages_config: Option<MessagesConfig>,
+}
+
+#[derive(Serialize)]
+struct LanguageConfig {
+    languages: Vec<String>,
+}
+
+#[derive(Serialize)]
+struct MessagesConfig {
+    receive_partial_transcripts: bool,
+}
+
+/// Session-creation response: the session id and the WebSocket URL to connect to.
+#[derive(Debug, Deserialize)]
+struct InitResponse {
+    id: String,
+    url: String,
+}
+
+/// Frames received over the live socket, discriminated by their `type` field.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "type")]
+enum GladiaMessage {
+    #[serde(rename = "transcript")]
+    Transcript(TranscriptMessage),
+    #[serde(rename = "start_session")]
+    StartSession { id: String },
+    #[serde(rename = "end_session")]
+    EndSession { id: String },
+    #[serde(rename = "speech_start")]
+    SpeechStart {
+        #[serde(default)]
+        session_id: Option<String>,
+    },
+    #[serde(rename = "speech_end")]
+    SpeechEnd {
+        #[serde(default)]
+        session_id: Option<String>,
+    },
+    #[serde(rename = "start_recording")]
+    StartRecording {
+        #[serde(default)]
+        session_id: Option<String>,
+    },
+    #[serde(rename = "end_recording")]
+    EndRecording {
+        #[serde(default)]
+        session_id: Option<String>,
+    },
+    #[serde(rename = "error")]
+    Error {
+        message: String,
+        // Held as raw JSON so that a numeric or a string code both deserialize.
+        #[serde(default)]
+        code: Option<serde_json::Value>,
+    },
+    #[serde(other)]
+    Unknown,
+}
+
+#[derive(Debug, Deserialize)]
+struct TranscriptMessage {
+    #[serde(default)]
+    session_id: String,
+    data: TranscriptData,
+}
+
+#[derive(Debug, Deserialize)]
+struct TranscriptData {
+    #[serde(default)]
+    id: String,
+    #[serde(default)]
+    is_final: bool,
+    utterance: Utterance,
+}
+
+#[derive(Debug, Deserialize)]
+struct Utterance {
+    #[serde(default)]
+    text: String,
+    #[serde(default)]
+    start: f64,
+    #[serde(default)]
+    end: f64,
+    #[serde(default)]
+    language: Option<String>,
+    #[serde(default)]
+    channel: Option<i32>,
+    #[serde(default)]
+    words: Vec<GladiaWord>,
+}
+
+#[derive(Debug, Deserialize)]
+struct GladiaWord {
+    #[serde(default)]
+    word: String,
+    #[serde(default)]
+    start: f64,
+    #[serde(default)]
+    end: f64,
+    #[serde(default)]
+    confidence: f64,
+}
+
+impl GladiaAdapter {
+    /// Converts a `transcript` frame into at most one `TranscriptResponse`; frames
+    /// with neither text nor words are dropped. `channel_index` is
+    /// `[channel, total_channels]`.
+    fn parse_transcript(msg: TranscriptMessage) -> Vec<StreamResponse> {
+        let session_id = msg.session_id;
+        let data = msg.data;
+        let utterance = data.utterance;
+
+        tracing::debug!(
+            transcript = %utterance.text,
+            is_final = data.is_final,
+            channel = ?utterance.channel,
+            session_id = %session_id,
+            "gladia_transcript_received"
+        );
+
+        if utterance.text.is_empty() && utterance.words.is_empty() {
+            return vec![];
+        }
+
+        let words: Vec<_> = utterance
+            .words
+            .iter()
+            .map(|w| {
+                WordBuilder::new(&w.word)
+                    .start(w.start)
+                    .end(w.end)
+                    .confidence(w.confidence)
+                    .language(utterance.language.clone())
+                    .build()
+            })
+            .collect();
+
+        let start = utterance.start;
+        // Guard against `end < start` so a malformed frame cannot yield a negative duration.
+        let duration = (utterance.end - utterance.start).max(0.0);
+
+        // A negative channel number would be meaningless as an index; treat it as channel 0.
+        let channel_idx = utterance.channel.unwrap_or(0).max(0);
+        let total_channels = session_channels()
+            .lock()
+            .unwrap_or_else(std::sync::PoisonError::into_inner)
+            .get(&session_id)
+            .copied()
+            // Unknown session: infer from the index, saturating (`256 as u8` would wrap to 0).
+            .unwrap_or_else(|| u8::try_from(channel_idx.saturating_add(1)).unwrap_or(u8::MAX));
+
+        let channel = Channel {
+            alternatives: vec![Alternatives {
+                transcript: utterance.text,
+                words,
+                confidence: 1.0,
+                languages: utterance.language.map(|l| vec![l]).unwrap_or_default(),
+            }],
+        };
+
+        vec![StreamResponse::TranscriptResponse {
+            is_final: data.is_final,
+            speech_final: data.is_final,
+            from_finalize: false,
+            start,
+            duration,
+            channel,
+            metadata: Metadata::default(),
+            channel_index: vec![channel_idx, i32::from(total_channels)],
+        }]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::GladiaAdapter;
+    use crate::test_utils::{run_dual_test, run_single_test};
+    use crate::ListenClient;
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_build_single() {
+        let client = ListenClient::builder()
+            .adapter::<GladiaAdapter>()
+            .api_base("https://api.gladia.io")
+            .api_key(std::env::var("GLADIA_API_KEY").expect("GLADIA_API_KEY not set"))
+            .params(owhisper_interface::ListenParams {
+                languages: vec![hypr_language::ISO639::En.into()],
+                ..Default::default()
+            })
+            .build_single();
+
+        run_single_test(client, "gladia").await;
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_build_dual() {
+        let client = ListenClient::builder()
+            .adapter::<GladiaAdapter>()
+            .api_base("https://api.gladia.io")
+            .api_key(std::env::var("GLADIA_API_KEY").expect("GLADIA_API_KEY not set"))
+            .params(owhisper_interface::ListenParams {
+                languages: vec![hypr_language::ISO639::En.into()],
+                ..Default::default()
+            })
+            .build_dual();
+
+        run_dual_test(client, "gladia").await;
+    }
+}
diff --git a/owhisper/owhisper-client/src/adapter/gladia/mod.rs b/owhisper/owhisper-client/src/adapter/gladia/mod.rs
new file mode 100644
index 0000000000..8fdd4f7f5c
--- /dev/null
+++ b/owhisper/owhisper-client/src/adapter/gladia/mod.rs
@@ -0,0 +1,131 @@
+mod live;
+
+pub(crate) const DEFAULT_API_HOST: &str = "api.gladia.io";
+pub(crate) const WS_PATH: &str = "/v2/live";
+
+/// Realtime speech-to-text adapter for Gladia's live API.
+#[derive(Clone, Default)]
+pub struct GladiaAdapter;
+
+impl GladiaAdapter {
+    pub fn is_supported_languages(_languages: &[hypr_language::Language]) -> bool {
+        true
+    }
+
+    pub fn is_host(base_url: &str) -> bool {
+        super::host_matches(base_url, Self::is_gladia_host)
+    }
+
+    /// Returns `true` when `host` is `gladia.io` or one of its subdomains.
+    ///
+    /// The match is anchored to the end of the name because a substring test
+    /// would also accept look-alikes such as `gladia.io.example.com`.
+    pub(crate) fn is_gladia_host(host: &str) -> bool {
+        // Tolerate a `:port` suffix in case the caller passes an authority.
+        // NOTE(review): assumes `host_matches` hands over a host, not a full URL.
+        let name = host.split(':').next().unwrap_or(host).to_ascii_lowercase();
+        name == "gladia.io" || name.ends_with(".gladia.io")
+    }
+
+    /// Derives the WebSocket URL for `api_base`, plus the query parameters that
+    /// the caller is expected to append to it.
+    ///
+    /// An empty base maps to the default Gladia endpoint and a proxy base is
+    /// delegated to `build_proxy_ws_url`. Any other base keeps its host and port
+    /// while the path becomes `WS_PATH`, using `ws` for local hosts and `wss`
+    /// otherwise.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `api_base` reaches the direct (non-proxy) path and does not
+    /// parse as a URL.
+    pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) {
+        if api_base.is_empty() {
+            return (
+                format!("wss://{}{}", DEFAULT_API_HOST, WS_PATH)
+                    .parse()
+                    .expect("invalid_default_ws_url"),
+                Vec::new(),
+            );
+        }
+
+        if let Some(proxy_result) = super::build_proxy_ws_url(api_base) {
+            return proxy_result;
+        }
+
+        let parsed: url::Url = api_base.parse().expect("invalid_api_base");
+        let existing_params = super::extract_query_params(&parsed);
+
+        let host = parsed.host_str().unwrap_or(DEFAULT_API_HOST);
+        let scheme = if super::is_local_host(host) {
+            "ws"
+        } else {
+            "wss"
+        };
+        let host_with_port = match parsed.port() {
+            Some(port) => format!("{host}:{port}"),
+            None => host.to_string(),
+        };
+
+        let url: url::Url = format!("{scheme}://{host_with_port}{WS_PATH}")
+            .parse()
+            .expect("invalid_ws_url");
+        (url, existing_params)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_is_gladia_host() {
+        for host in ["gladia.io", "api.gladia.io", "API.GLADIA.IO", "api.gladia.io:8443"] {
+            assert!(GladiaAdapter::is_gladia_host(host), "host: {}", host);
+        }
+        for host in ["gladia.io.example.com", "notgladia.io", "api.hyprnote.com", ""] {
+            assert!(!GladiaAdapter::is_gladia_host(host), "host: {}", host);
+        }
+    }
+
+    #[test]
+    fn test_build_ws_url_from_base() {
+        let cases = [
+            ("", "wss://api.gladia.io/v2/live", vec![]),
+            (
+                "https://api.gladia.io",
+                "wss://api.gladia.io/v2/live",
+                vec![],
+            ),
+            (
+                "https://api.gladia.io:8443",
+                "wss://api.gladia.io:8443/v2/live",
+                vec![],
+            ),
+            (
+                "https://api.hyprnote.com?provider=gladia",
+                "wss://api.hyprnote.com/listen",
+                vec![("provider", "gladia")],
+            ),
+            (
+                "http://localhost:8787/listen?provider=gladia",
+                "ws://localhost:8787/listen",
+                vec![("provider", "gladia")],
+            ),
+        ];
+
+        for (input, expected_url, expected_params) in cases {
+            let (url, params) = GladiaAdapter::build_ws_url_from_base(input);
+            assert_eq!(url.as_str(), expected_url, "input: {}", input);
+            assert_eq!(
+                params,
+                expected_params
+                    .into_iter()
+                    .map(|(k, v)| (k.to_string(), v.to_string()))
+                    .collect::<Vec<_>>(),
+                "input: {}",
+                input
+            );
+        }
+    }
+}
diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs
index bfa00e9d36..5b04e6bffc 100644
--- a/owhisper/owhisper-client/src/adapter/mod.rs
+++ b/owhisper/owhisper-client/src/adapter/mod.rs
@@ -3,6 +3,7 @@ mod assemblyai;
 mod deepgram;
 mod deepgram_compat;
 mod fireworks;
+mod gladia;
 mod owhisper;
 pub mod parsing;
 mod soniox;
@@ -14,6 +15,7 @@ pub use argmax::*;
 pub use assemblyai::*;
 pub use deepgram::*;
 pub use fireworks::*;
+pub use gladia::*;
 pub use soniox::*;
 
 use std::future::Future;
@@ -36,6 +38,21 @@ pub trait RealtimeSttAdapter: Clone + Default + Send + Sync + 'static {
 
     fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url;
 
+    /// Builds the WebSocket URL for providers that need the API key to do so.
+    ///
+    /// The default ignores the key and defers to `build_ws_url`. Returning
+    /// `None` means no URL could be produced with the given key; the client
+    /// builder then falls back to `build_ws_url`.
+    fn build_ws_url_with_api_key(
+        &self,
+        api_base: &str,
+        params: &ListenParams,
+        channels: u8,
+        _api_key: Option<&str>,
+    ) -> Option<url::Url> {
+        Some(self.build_ws_url(api_base, params, channels))
+    }
+
     fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)>;
 
     fn keep_alive_message(&self) -> Option<Message>;
diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs
index 9ef71e2160..18ecd158d3 100644
--- a/owhisper/owhisper-client/src/lib.rs
+++ b/owhisper/owhisper-client/src/lib.rs
@@ -72,7 +72,17 @@ impl ListenClientBuilder {
     fn build_request(&self, adapter: &A, channels: u8) -> hypr_ws::client::ClientRequestBuilder {
         let params = self.get_params();
         let api_base = append_provider_param(self.get_api_base(), adapter.provider_name());
-        let url = adapter.build_ws_url(&api_base, &params, channels);
+        let url = adapter
+            .build_ws_url_with_api_key(&api_base, &params, channels, self.api_key.as_deref())
+            .unwrap_or_else(|| {
+                // Surface the fallback: the keyless URL may be rejected by providers
+                // that require a session to be created first.
+                tracing::warn!(
+                    provider = adapter.provider_name(),
+                    "ws_url_with_api_key_unavailable"
+                );
+                adapter.build_ws_url(&api_base, &params, channels)
+            });
         let uri = url.to_string().parse().unwrap();
 
         let mut request = hypr_ws::client::ClientRequestBuilder::new(uri);