diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs
index 5b04e6bffc..8ba491881d 100644
--- a/owhisper/owhisper-client/src/adapter/mod.rs
+++ b/owhisper/owhisper-client/src/adapter/mod.rs
@@ -4,6 +4,7 @@ mod deepgram;
 mod deepgram_compat;
 mod fireworks;
 mod gladia;
+mod openai;
 mod owhisper;
 pub mod parsing;
 mod soniox;
@@ -16,6 +17,7 @@ pub use assemblyai::*;
 pub use deepgram::*;
 pub use fireworks::*;
 pub use gladia::*;
+pub use openai::*;
 pub use soniox::*;
 
 use std::future::Future;
@@ -164,6 +166,7 @@ pub enum AdapterKind {
     Fireworks,
     Deepgram,
     AssemblyAI,
+    OpenAI,
 }
 
 impl AdapterKind {
@@ -182,6 +185,8 @@
             Self::Soniox
         } else if FireworksAdapter::is_host(base_url) {
             Self::Fireworks
+        } else if OpenAIAdapter::is_host(base_url) {
+            Self::OpenAI
         } else {
             Self::Deepgram
         }
diff --git a/owhisper/owhisper-client/src/adapter/openai/live.rs b/owhisper/owhisper-client/src/adapter/openai/live.rs
new file mode 100644
index 0000000000..55189db4ad
--- /dev/null
+++ b/owhisper/owhisper-client/src/adapter/openai/live.rs
@@ -0,0 +1,360 @@
+use hypr_ws::client::Message;
+use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse};
+use owhisper_interface::ListenParams;
+use serde::{Deserialize, Serialize};
+
+use super::OpenAIAdapter;
+use crate::adapter::parsing::{calculate_time_span, WordBuilder};
+use crate::adapter::RealtimeSttAdapter;
+
+impl RealtimeSttAdapter for OpenAIAdapter {
+    fn provider_name(&self) -> &'static str {
+        "openai"
+    }
+
+    fn supports_native_multichannel(&self) -> bool {
+        false
+    }
+
+    fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url {
+        let (mut url, existing_params) =
+            Self::build_ws_url_from_base(api_base, params.model.as_deref());
+
+        if !existing_params.is_empty() {
+            let mut query_pairs = url.query_pairs_mut();
+            for (key, value) in &existing_params {
+                query_pairs.append_pair(key, value);
+            }
+        }
+
+        url
+    }
+
+    fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> {
+        api_key.map(|key| ("Authorization", format!("Bearer {}", key)))
+    }
+
+    fn keep_alive_message(&self) -> Option<Message> {
+        None
+    }
+
+    fn initial_message(
+        &self,
+        _api_key: Option<&str>,
+        params: &ListenParams,
+        _channels: u8,
+    ) -> Option<Message> {
+        let language = params
+            .languages
+            .first()
+            .map(|l| l.iso639().code().to_string());
+
+        let model = params.model.as_deref().unwrap_or(super::DEFAULT_MODEL);
+
+        let session_config = SessionUpdateEvent {
+            event_type: "session.update".to_string(),
+            session: SessionConfig {
+                session_type: "transcription".to_string(),
+                audio: Some(AudioConfig {
+                    input: Some(AudioInputConfig {
+                        format: Some(AudioFormat {
+                            format_type: "audio/pcm".to_string(),
+                            rate: 24000,
+                        }),
+                        transcription: Some(TranscriptionConfig {
+                            model: model.to_string(),
+                            language,
+                        }),
+                        turn_detection: Some(TurnDetection {
+                            detection_type: "server_vad".to_string(),
+                            threshold: Some(0.5),
+                            prefix_padding_ms: Some(300),
+                            silence_duration_ms: Some(500),
+                        }),
+                    }),
+                }),
+                include: Some(vec!["item.input_audio_transcription.logprobs".to_string()]),
+            },
+        };
+
+        let json = serde_json::to_string(&session_config).ok()?;
+        Some(Message::Text(json.into()))
+    }
+
+    fn finalize_message(&self) -> Message {
+        let commit = InputAudioBufferCommit {
+            event_type: "input_audio_buffer.commit".to_string(),
+        };
+        Message::Text(serde_json::to_string(&commit).unwrap().into())
+    }
+
+    fn parse_response(&self, raw: &str) -> Vec<StreamResponse> {
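+        // Each websocket frame is a JSON event tagged by "type"; anything we
+        // don't recognize deserializes into `OpenAIEvent::Unknown` below.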
+        let event: OpenAIEvent = match serde_json::from_str(raw) {
+            Ok(e) => e,
+            Err(e) => {
+                tracing::warn!(error = ?e, raw = raw, "openai_json_parse_failed");
+                return vec![];
+            }
+        };
+
+        match event {
+            OpenAIEvent::SessionCreated { session } => {
+                tracing::debug!(session_id = %session.id, "openai_session_created");
+                vec![]
+            }
+            OpenAIEvent::SessionUpdated { session } => {
+                tracing::debug!(session_id = %session.id, "openai_session_updated");
+                vec![]
+            }
+            OpenAIEvent::InputAudioBufferCommitted { item_id } => {
+                tracing::debug!(item_id = %item_id, "openai_audio_buffer_committed");
+                vec![]
+            }
+            OpenAIEvent::InputAudioBufferCleared => {
+                tracing::debug!("openai_audio_buffer_cleared");
+                vec![]
+            }
+            OpenAIEvent::ConversationItemInputAudioTranscriptionCompleted {
+                item_id,
+                content_index,
+                transcript,
+            } => {
+                tracing::debug!(
+                    item_id = %item_id,
+                    content_index = content_index,
+                    transcript = %transcript,
+                    "openai_transcription_completed"
+                );
+                Self::build_transcript_response(&transcript, true, true)
+            }
+            OpenAIEvent::ConversationItemInputAudioTranscriptionDelta {
+                item_id,
+                content_index,
+                delta,
+            } => {
+                tracing::debug!(
+                    item_id = %item_id,
+                    content_index = content_index,
+                    delta = %delta,
+                    "openai_transcription_delta"
+                );
+                Self::build_transcript_response(&delta, false, false)
+            }
+            OpenAIEvent::ConversationItemInputAudioTranscriptionFailed {
+                item_id, error, ..
+            } => {
+                tracing::error!(
+                    item_id = %item_id,
+                    error_type = %error.error_type,
+                    error_message = %error.message,
+                    "openai_transcription_failed"
+                );
+                vec![]
+            }
+            OpenAIEvent::Error { error } => {
+                tracing::error!(
+                    error_type = %error.error_type,
+                    error_message = %error.message,
+                    "openai_error"
+                );
+                vec![]
+            }
+            OpenAIEvent::Unknown => {
+                tracing::debug!(raw = raw, "openai_unknown_event");
+                vec![]
+            }
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+struct SessionUpdateEvent {
+    #[serde(rename = "type")]
+    event_type: String,
+    session: SessionConfig,
+}
+
+#[derive(Debug, Serialize)]
+struct SessionConfig {
+    #[serde(rename = "type")]
+    session_type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    audio: Option<AudioConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    include: Option<Vec<String>>,
+}
+
+#[derive(Debug, Serialize)]
+struct AudioConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    input: Option<AudioInputConfig>,
+}
+
+#[derive(Debug, Serialize)]
+struct AudioInputConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    format: Option<AudioFormat>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    transcription: Option<TranscriptionConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    turn_detection: Option<TurnDetection>,
+}
+
+#[derive(Debug, Serialize)]
+struct AudioFormat {
+    #[serde(rename = "type")]
+    format_type: String,
+    rate: u32,
+}
+
+#[derive(Debug, Serialize)]
+struct TranscriptionConfig {
+    model: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    language: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+struct TurnDetection {
+    #[serde(rename = "type")]
+    detection_type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    threshold: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    prefix_padding_ms: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    silence_duration_ms: Option<u32>,
+}
+
+#[derive(Debug, Serialize)]
+struct InputAudioBufferCommit {
+    #[serde(rename = "type")]
+    event_type: String,
+}
+
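+// Server-to-client events of the realtime transcription session, tagged by
+// "type". The `#[serde(other)]` arm absorbs event types we do not handle.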
"session.updated")] + SessionUpdated { session: SessionInfo }, + #[serde(rename = "input_audio_buffer.committed")] + InputAudioBufferCommitted { item_id: String }, + #[serde(rename = "input_audio_buffer.cleared")] + InputAudioBufferCleared, + #[serde(rename = "conversation.item.input_audio_transcription.completed")] + ConversationItemInputAudioTranscriptionCompleted { + item_id: String, + content_index: u32, + transcript: String, + }, + #[serde(rename = "conversation.item.input_audio_transcription.delta")] + ConversationItemInputAudioTranscriptionDelta { + item_id: String, + content_index: u32, + delta: String, + }, + #[serde(rename = "conversation.item.input_audio_transcription.failed")] + ConversationItemInputAudioTranscriptionFailed { + item_id: String, + content_index: u32, + error: OpenAIError, + }, + #[serde(rename = "error")] + Error { error: OpenAIError }, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Deserialize)] +struct SessionInfo { + id: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAIError { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +impl OpenAIAdapter { + fn build_transcript_response( + transcript: &str, + is_final: bool, + speech_final: bool, + ) -> Vec { + if transcript.is_empty() { + return vec![]; + } + + let words: Vec<_> = transcript + .split_whitespace() + .map(|word| WordBuilder::new(word).confidence(1.0).build()) + .collect(); + + let (start, duration) = calculate_time_span(&words); + + let channel = Channel { + alternatives: vec![Alternatives { + transcript: transcript.to_string(), + words, + confidence: 1.0, + languages: vec![], + }], + }; + + vec![StreamResponse::TranscriptResponse { + is_final, + speech_final, + from_finalize: false, + start, + duration, + channel, + metadata: Metadata::default(), + channel_index: vec![0, 1], + }] + } +} + +#[cfg(test)] +mod tests { + use super::OpenAIAdapter; + use crate::test_utils::{run_dual_test, run_single_test}; + use crate::ListenClient; + + #[tokio::test] + #[ignore] + async fn test_build_single() { + let client = ListenClient::builder() + .adapter::() + .api_base("wss://api.openai.com") + .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")) + .params(owhisper_interface::ListenParams { + model: Some("gpt-4o-transcribe".to_string()), + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .build_single(); + + run_single_test(client, "openai").await; + } + + #[tokio::test] + #[ignore] + async fn test_build_dual() { + let client = ListenClient::builder() + .adapter::() + .api_base("wss://api.openai.com") + .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")) + .params(owhisper_interface::ListenParams { + model: Some("gpt-4o-transcribe".to_string()), + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .build_dual(); + + run_dual_test(client, "openai").await; + } +} diff --git a/owhisper/owhisper-client/src/adapter/openai/mod.rs b/owhisper/owhisper-client/src/adapter/openai/mod.rs new file mode 100644 index 0000000000..06fe96d08a --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/openai/mod.rs @@ -0,0 +1,110 @@ +mod live; + +pub(crate) const DEFAULT_WS_HOST: &str = "api.openai.com"; +pub(crate) const WS_PATH: &str = "/v1/realtime"; +pub(crate) const DEFAULT_MODEL: &str = "gpt-4o-transcribe"; + +#[derive(Clone, Default)] +pub struct OpenAIAdapter; + +impl OpenAIAdapter { + pub fn is_supported_languages(_languages: &[hypr_language::Language]) -> bool { + true + } + + pub fn 
+    pub fn is_host(base_url: &str) -> bool {
+        super::host_matches(base_url, Self::is_openai_host)
+    }
+
+    pub(crate) fn is_openai_host(host: &str) -> bool {
+        host.contains("openai.com")
+    }
+
+    pub(crate) fn build_ws_url_from_base(
+        api_base: &str,
+        model: Option<&str>,
+    ) -> (url::Url, Vec<(String, String)>) {
+        if api_base.is_empty() {
+            let model = model.unwrap_or(DEFAULT_MODEL);
+            return (
+                format!("wss://{}{}", DEFAULT_WS_HOST, WS_PATH)
+                    .parse()
+                    .expect("invalid_default_ws_url"),
+                vec![("model".to_string(), model.to_string())],
+            );
+        }
+
+        if let Some(proxy_result) = super::build_proxy_ws_url(api_base) {
+            return proxy_result;
+        }
+
+        let parsed: url::Url = api_base.parse().expect("invalid_api_base");
+        let mut existing_params = super::extract_query_params(&parsed);
+
+        if !existing_params.iter().any(|(k, _)| k == "model") {
+            let model = model.unwrap_or(DEFAULT_MODEL);
+            existing_params.push(("model".to_string(), model.to_string()));
+        }
+
+        let host = parsed.host_str().unwrap_or(DEFAULT_WS_HOST);
+        let mut url: url::Url = format!("wss://{}{}", host, WS_PATH)
+            .parse()
+            .expect("invalid_ws_url");
+
+        super::set_scheme_from_host(&mut url);
+
+        (url, existing_params)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_build_ws_url_from_base_empty() {
+        let (url, params) = OpenAIAdapter::build_ws_url_from_base("", None);
+        assert_eq!(url.as_str(), "wss://api.openai.com/v1/realtime");
+        assert_eq!(
+            params,
+            vec![("model".to_string(), "gpt-4o-transcribe".to_string())]
+        );
+    }
+
+    #[test]
+    fn test_build_ws_url_from_base_with_model() {
+        let (url, params) =
+            OpenAIAdapter::build_ws_url_from_base("", Some("gpt-4o-mini-realtime-preview"));
+        assert_eq!(url.as_str(), "wss://api.openai.com/v1/realtime");
+        assert_eq!(
+            params,
+            vec![(
+                "model".to_string(),
+                "gpt-4o-mini-realtime-preview".to_string()
+            )]
+        );
+    }
+
+    #[test]
+    fn test_build_ws_url_from_base_proxy() {
+        let (url, params) =
+            OpenAIAdapter::build_ws_url_from_base("https://api.hyprnote.com?provider=openai", None);
+        assert_eq!(url.as_str(), "wss://api.hyprnote.com/listen");
+        assert_eq!(params, vec![("provider".to_string(), "openai".to_string())]);
+    }
+
+    #[test]
+    fn test_build_ws_url_from_base_localhost() {
+        let (url, params) =
+            OpenAIAdapter::build_ws_url_from_base("http://localhost:8787?provider=openai", None);
+        assert_eq!(url.as_str(), "ws://localhost:8787/listen");
+        assert_eq!(params, vec![("provider".to_string(), "openai".to_string())]);
+    }
+
+    #[test]
+    fn test_is_openai_host() {
+        assert!(OpenAIAdapter::is_openai_host("api.openai.com"));
+        assert!(OpenAIAdapter::is_openai_host("openai.com"));
+        assert!(!OpenAIAdapter::is_openai_host("api.deepgram.com"));
+    }
+}