5 changes: 5 additions & 0 deletions owhisper/owhisper-client/src/adapter/mod.rs
@@ -4,6 +4,7 @@ mod deepgram
mod deepgram_compat;
mod fireworks;
mod gladia;
mod openai;
mod owhisper;
pub mod parsing;
mod soniox;
@@ -16,6 +17,7 @@ pub use assemblyai::*;
pub use deepgram::*;
pub use fireworks::*;
pub use gladia::*;
pub use openai::*;
pub use soniox::*;

use std::future::Future;
@@ -164,6 +166,7 @@ pub enum AdapterKind {
    Fireworks,
    Deepgram,
    AssemblyAI,
    OpenAI,
}

impl AdapterKind {
@@ -182,6 +185,8 @@ impl AdapterKind {
            Self::Soniox
        } else if FireworksAdapter::is_host(base_url) {
            Self::Fireworks
        } else if OpenAIAdapter::is_host(base_url) {
            Self::OpenAI
        } else {
            Self::Deepgram
        }
360 changes: 360 additions & 0 deletions owhisper/owhisper-client/src/adapter/openai/live.rs
@@ -0,0 +1,360 @@
use hypr_ws::client::Message;
use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse};
use owhisper_interface::ListenParams;
use serde::{Deserialize, Serialize};

use super::OpenAIAdapter;
use crate::adapter::parsing::{calculate_time_span, WordBuilder};
use crate::adapter::RealtimeSttAdapter;

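// Wires OpenAI's Realtime transcription WebSocket into the shared realtime STT
// adapter trait used by the other providers in this crate.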
impl RealtimeSttAdapter for OpenAIAdapter {
    fn provider_name(&self) -> &'static str {
        "openai"
    }

    fn supports_native_multichannel(&self) -> bool {
        false
    }

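    // Builds the WebSocket URL from the configured base and re-applies any
    // query parameters returned by the shared URL helper.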
    fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url {
        let (mut url, existing_params) =
            Self::build_ws_url_from_base(api_base, params.model.as_deref());

        if !existing_params.is_empty() {
            let mut query_pairs = url.query_pairs_mut();
            for (key, value) in &existing_params {
                query_pairs.append_pair(key, value);
            }
        }

        url
    }

    fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> {
        api_key.map(|key| ("Authorization", format!("Bearer {}", key)))
    }

    fn keep_alive_message(&self) -> Option<Message> {
        None
    }

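    // First message on the socket: a `session.update` event that configures a
    // transcription session with PCM input at 24 kHz, the requested model and
    // language, server-side VAD, and logprobs included in transcription items.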
    fn initial_message(
        &self,
        _api_key: Option<&str>,
        params: &ListenParams,
        _channels: u8,
    ) -> Option<Message> {
        let language = params
            .languages
            .first()
            .map(|l| l.iso639().code().to_string());

        let model = params.model.as_deref().unwrap_or(super::DEFAULT_MODEL);

        let session_config = SessionUpdateEvent {
            event_type: "session.update".to_string(),
            session: SessionConfig {
                session_type: "transcription".to_string(),
                audio: Some(AudioConfig {
                    input: Some(AudioInputConfig {
                        format: Some(AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: 24000,
                        }),
                        transcription: Some(TranscriptionConfig {
                            model: model.to_string(),
                            language,
                        }),
                        turn_detection: Some(TurnDetection {
                            detection_type: "server_vad".to_string(),
                            threshold: Some(0.5),
                            prefix_padding_ms: Some(300),
                            silence_duration_ms: Some(500),
                        }),
                    }),
                }),
                include: Some(vec!["item.input_audio_transcription.logprobs".to_string()]),
            },
        };

        let json = serde_json::to_string(&session_config).ok()?;
        Some(Message::Text(json.into()))
    }

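    // Finalization asks the server to commit whatever audio is still buffered.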
    fn finalize_message(&self) -> Message {
        let commit = InputAudioBufferCommit {
            event_type: "input_audio_buffer.commit".to_string(),
        };
        Message::Text(serde_json::to_string(&commit).unwrap().into())
    }

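    // Maps Realtime API server events onto `StreamResponse`s: transcription
    // deltas become interim results, completed transcripts become final ones,
    // and everything else is logged and dropped.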
    fn parse_response(&self, raw: &str) -> Vec<StreamResponse> {
        let event: OpenAIEvent = match serde_json::from_str(raw) {
            Ok(e) => e,
            Err(e) => {
                tracing::warn!(error = ?e, raw = raw, "openai_json_parse_failed");
                return vec![];
            }
        };

        match event {
            OpenAIEvent::SessionCreated { session } => {
                tracing::debug!(session_id = %session.id, "openai_session_created");
                vec![]
            }
            OpenAIEvent::SessionUpdated { session } => {
                tracing::debug!(session_id = %session.id, "openai_session_updated");
                vec![]
            }
            OpenAIEvent::InputAudioBufferCommitted { item_id } => {
                tracing::debug!(item_id = %item_id, "openai_audio_buffer_committed");
                vec![]
            }
            OpenAIEvent::InputAudioBufferCleared => {
                tracing::debug!("openai_audio_buffer_cleared");
                vec![]
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionCompleted {
                item_id,
                content_index,
                transcript,
            } => {
                tracing::debug!(
                    item_id = %item_id,
                    content_index = content_index,
                    transcript = %transcript,
                    "openai_transcription_completed"
                );
                Self::build_transcript_response(&transcript, true, true)
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionDelta {
                item_id,
                content_index,
                delta,
            } => {
                tracing::debug!(
                    item_id = %item_id,
                    content_index = content_index,
                    delta = %delta,
                    "openai_transcription_delta"
                );
                Self::build_transcript_response(&delta, false, false)
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionFailed {
                item_id, error, ..
            } => {
                tracing::error!(
                    item_id = %item_id,
                    error_type = %error.error_type,
                    error_message = %error.message,
                    "openai_transcription_failed"
                );
                vec![]
            }
            OpenAIEvent::Error { error } => {
                tracing::error!(
                    error_type = %error.error_type,
                    error_message = %error.message,
                    "openai_error"
                );
                vec![]
            }
            OpenAIEvent::Unknown => {
                tracing::debug!(raw = raw, "openai_unknown_event");
                vec![]
            }
        }
    }
}

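// Client -> server events, serialized as JSON text frames.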
#[derive(Debug, Serialize)]
struct SessionUpdateEvent {
    #[serde(rename = "type")]
    event_type: String,
    session: SessionConfig,
}

#[derive(Debug, Serialize)]
struct SessionConfig {
    #[serde(rename = "type")]
    session_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    audio: Option<AudioConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    include: Option<Vec<String>>,
}

#[derive(Debug, Serialize)]
struct AudioConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    input: Option<AudioInputConfig>,
}

#[derive(Debug, Serialize)]
struct AudioInputConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<AudioFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    transcription: Option<TranscriptionConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    turn_detection: Option<TurnDetection>,
}

#[derive(Debug, Serialize)]
struct AudioFormat {
    #[serde(rename = "type")]
    format_type: String,
    rate: u32,
}

#[derive(Debug, Serialize)]
struct TranscriptionConfig {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    language: Option<String>,
}

#[derive(Debug, Serialize)]
struct TurnDetection {
    #[serde(rename = "type")]
    detection_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    threshold: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prefix_padding_ms: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    silence_duration_ms: Option<u32>,
}

#[derive(Debug, Serialize)]
struct InputAudioBufferCommit {
    #[serde(rename = "type")]
    event_type: String,
}

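// Server -> client events, tagged by `type`; unrecognized event types fall
// through to `Unknown` instead of failing deserialization.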
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum OpenAIEvent {
    #[serde(rename = "session.created")]
    SessionCreated { session: SessionInfo },
    #[serde(rename = "session.updated")]
    SessionUpdated { session: SessionInfo },
    #[serde(rename = "input_audio_buffer.committed")]
    InputAudioBufferCommitted { item_id: String },
    #[serde(rename = "input_audio_buffer.cleared")]
    InputAudioBufferCleared,
    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
    ConversationItemInputAudioTranscriptionCompleted {
        item_id: String,
        content_index: u32,
        transcript: String,
    },
    #[serde(rename = "conversation.item.input_audio_transcription.delta")]
    ConversationItemInputAudioTranscriptionDelta {
        item_id: String,
        content_index: u32,
        delta: String,
    },
    #[serde(rename = "conversation.item.input_audio_transcription.failed")]
    ConversationItemInputAudioTranscriptionFailed {
        item_id: String,
        content_index: u32,
        error: OpenAIError,
    },
    #[serde(rename = "error")]
    Error { error: OpenAIError },
    #[serde(other)]
    Unknown,
}

#[derive(Debug, Deserialize)]
struct SessionInfo {
    id: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIError {
    #[serde(rename = "type")]
    error_type: String,
    message: String,
}

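// The Realtime transcription events carry plain text without word timings, so
// words are synthesized from whitespace splits with a fixed confidence of 1.0
// and the time span is derived from the resulting word list.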
impl OpenAIAdapter {
    fn build_transcript_response(
        transcript: &str,
        is_final: bool,
        speech_final: bool,
    ) -> Vec<StreamResponse> {
        if transcript.is_empty() {
            return vec![];
        }

        let words: Vec<_> = transcript
            .split_whitespace()
            .map(|word| WordBuilder::new(word).confidence(1.0).build())
            .collect();

        let (start, duration) = calculate_time_span(&words);

        let channel = Channel {
            alternatives: vec![Alternatives {
                transcript: transcript.to_string(),
                words,
                confidence: 1.0,
                languages: vec![],
            }],
        };

        vec![StreamResponse::TranscriptResponse {
            is_final,
            speech_final,
            from_finalize: false,
            start,
            duration,
            channel,
            metadata: Metadata::default(),
            channel_index: vec![0, 1],
        }]
    }
}

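// These tests hit the live OpenAI endpoint, so they are `#[ignore]`d by
// default; run them with `cargo test -- --ignored` and OPENAI_API_KEY set.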
#[cfg(test)]
mod tests {
    use super::OpenAIAdapter;
    use crate::test_utils::{run_dual_test, run_single_test};
    use crate::ListenClient;

    #[tokio::test]
    #[ignore]
    async fn test_build_single() {
        let client = ListenClient::builder()
            .adapter::<OpenAIAdapter>()
            .api_base("wss://api.openai.com")
            .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
            .params(owhisper_interface::ListenParams {
                model: Some("gpt-4o-transcribe".to_string()),
                languages: vec![hypr_language::ISO639::En.into()],
                ..Default::default()
            })
            .build_single();

        run_single_test(client, "openai").await;
    }

    #[tokio::test]
    #[ignore]
    async fn test_build_dual() {
        let client = ListenClient::builder()
            .adapter::<OpenAIAdapter>()
            .api_base("wss://api.openai.com")
            .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
            .params(owhisper_interface::ListenParams {
                model: Some("gpt-4o-transcribe".to_string()),
                languages: vec![hypr_language::ISO639::En.into()],
                ..Default::default()
            })
            .build_dual();

        run_dual_test(client, "openai").await;
    }
}