5 changes: 5 additions & 0 deletions owhisper/owhisper-client/src/adapter/mod.rs
@@ -4,6 +4,7 @@ mod deepgram;
 mod deepgram_compat;
 mod fireworks;
 mod gladia;
+mod openai;
 mod owhisper;
 pub mod parsing;
 mod soniox;
@@ -16,6 +17,7 @@ pub use assemblyai::*;
 pub use deepgram::*;
 pub use fireworks::*;
 pub use gladia::*;
+pub use openai::*;
 pub use soniox::*;

 use std::future::Future;
@@ -164,6 +166,7 @@ pub enum AdapterKind {
     Fireworks,
     Deepgram,
     AssemblyAI,
+    OpenAI,
 }

 impl AdapterKind {
@@ -182,6 +185,8 @@ impl AdapterKind {
             Self::Soniox
         } else if FireworksAdapter::is_host(base_url) {
             Self::Fireworks
+        } else if OpenAIAdapter::is_host(base_url) {
+            Self::OpenAI
         } else {
             Self::Deepgram
         }
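Host detection falls through the chain in order and defaults to `Self::Deepgram` when nothing matches. `OpenAIAdapter::is_host` itself is not part of this diff; a minimal sketch of what it presumably checks, assuming `base_url` is a `&url::Url` and the `api.openai.com` host used by the tests below:

impl OpenAIAdapter {
    // Hypothetical sketch, not code from this PR: match the host the tests
    // below connect to ("wss://api.openai.com").
    pub fn is_host(base_url: &url::Url) -> bool {
        base_url.host_str() == Some("api.openai.com")
    }
}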
358 changes: 358 additions & 0 deletions owhisper/owhisper-client/src/adapter/openai/live.rs
@@ -0,0 +1,358 @@
use hypr_ws::client::Message;
use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse};
use owhisper_interface::ListenParams;
use serde::{Deserialize, Serialize};

use super::OpenAIAdapter;
use crate::adapter::parsing::{calculate_time_span, WordBuilder};
use crate::adapter::RealtimeSttAdapter;

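// Maps OpenAI's realtime transcription WebSocket events onto the
// provider-agnostic `RealtimeSttAdapter` interface.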
impl RealtimeSttAdapter for OpenAIAdapter {
    fn provider_name(&self) -> &'static str {
        "openai"
    }

    fn supports_native_multichannel(&self) -> bool {
        false
    }

    fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url {
        let (mut url, existing_params) =
            Self::build_ws_url_from_base(api_base, params.model.as_deref());

        if !existing_params.is_empty() {
            let mut query_pairs = url.query_pairs_mut();
            for (key, value) in &existing_params {
                query_pairs.append_pair(key, value);
            }
        }

        url
    }

    fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> {
        api_key.map(|key| ("Authorization", format!("Bearer {}", key)))
    }

    fn keep_alive_message(&self) -> Option<Message> {
        None
    }

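    // First frame sent after connect: a `session.update` event that puts the
    // realtime session into transcription mode (24 kHz PCM input, server-side
    // VAD) and asks for logprobs. Note the transcription model is hardcoded
    // here rather than taken from `params.model`.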
    fn initial_message(
        &self,
        _api_key: Option<&str>,
        params: &ListenParams,
        _channels: u8,
    ) -> Option<Message> {
        let language = params
            .languages
            .first()
            .map(|l| l.iso639().code().to_string());

        let session_config = SessionUpdateEvent {
            event_type: "session.update".to_string(),
            session: SessionConfig {
                session_type: "transcription".to_string(),
                audio: Some(AudioConfig {
                    input: Some(AudioInputConfig {
                        format: Some(AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: 24000,
                        }),
                        transcription: Some(TranscriptionConfig {
                            model: "gpt-4o-transcribe".to_string(),
                            language,
                        }),
                        turn_detection: Some(TurnDetection {
                            detection_type: "server_vad".to_string(),
                            threshold: Some(0.5),
                            prefix_padding_ms: Some(300),
                            silence_duration_ms: Some(500),
                        }),
                    }),
                }),
                include: Some(vec!["item.input_audio_transcription.logprobs".to_string()]),
            },
        };

        let json = serde_json::to_string(&session_config).ok()?;
        Some(Message::Text(json.into()))
    }

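    // `input_audio_buffer.commit` flushes any audio still buffered server-side
    // so a final transcript is produced before the stream shuts down.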
    fn finalize_message(&self) -> Message {
        let commit = InputAudioBufferCommit {
            event_type: "input_audio_buffer.commit".to_string(),
        };
        Message::Text(serde_json::to_string(&commit).unwrap().into())
    }

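    // Translates each inbound JSON event into zero or more `StreamResponse`s:
    // deltas become interim transcripts, `completed` events become final ones,
    // and everything else is logged and dropped.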
    fn parse_response(&self, raw: &str) -> Vec<StreamResponse> {
        let event: OpenAIEvent = match serde_json::from_str(raw) {
            Ok(e) => e,
            Err(e) => {
                tracing::warn!(error = ?e, raw = raw, "openai_json_parse_failed");
                return vec![];
            }
        };

        match event {
            OpenAIEvent::SessionCreated { session } => {
                tracing::debug!(session_id = %session.id, "openai_session_created");
                vec![]
            }
            OpenAIEvent::SessionUpdated { session } => {
                tracing::debug!(session_id = %session.id, "openai_session_updated");
                vec![]
            }
            OpenAIEvent::InputAudioBufferCommitted { item_id } => {
                tracing::debug!(item_id = %item_id, "openai_audio_buffer_committed");
                vec![]
            }
            OpenAIEvent::InputAudioBufferCleared => {
                tracing::debug!("openai_audio_buffer_cleared");
                vec![]
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionCompleted {
                item_id,
                content_index,
                transcript,
            } => {
                tracing::debug!(
                    item_id = %item_id,
                    content_index = content_index,
                    transcript = %transcript,
                    "openai_transcription_completed"
                );
                Self::build_transcript_response(&transcript, true, true)
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionDelta {
                item_id,
                content_index,
                delta,
            } => {
                tracing::debug!(
                    item_id = %item_id,
                    content_index = content_index,
                    delta = %delta,
                    "openai_transcription_delta"
                );
                Self::build_transcript_response(&delta, false, false)
            }
            OpenAIEvent::ConversationItemInputAudioTranscriptionFailed {
                item_id, error, ..
            } => {
                tracing::error!(
                    item_id = %item_id,
                    error_type = %error.error_type,
                    error_message = %error.message,
                    "openai_transcription_failed"
                );
                vec![]
            }
            OpenAIEvent::Error { error } => {
                tracing::error!(
                    error_type = %error.error_type,
                    error_message = %error.message,
                    "openai_error"
                );
                vec![]
            }
            OpenAIEvent::Unknown => {
                tracing::debug!(raw = raw, "openai_unknown_event");
                vec![]
            }
        }
    }
}

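// Outbound (client -> server) event payloads, serialized as JSON text frames.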
#[derive(Debug, Serialize)]
struct SessionUpdateEvent {
    #[serde(rename = "type")]
    event_type: String,
    session: SessionConfig,
}

#[derive(Debug, Serialize)]
struct SessionConfig {
    #[serde(rename = "type")]
    session_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    audio: Option<AudioConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    include: Option<Vec<String>>,
}

#[derive(Debug, Serialize)]
struct AudioConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    input: Option<AudioInputConfig>,
}

#[derive(Debug, Serialize)]
struct AudioInputConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<AudioFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    transcription: Option<TranscriptionConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    turn_detection: Option<TurnDetection>,
}

#[derive(Debug, Serialize)]
struct AudioFormat {
    #[serde(rename = "type")]
    format_type: String,
    rate: u32,
}

#[derive(Debug, Serialize)]
struct TranscriptionConfig {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    language: Option<String>,
}

#[derive(Debug, Serialize)]
struct TurnDetection {
    #[serde(rename = "type")]
    detection_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    threshold: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prefix_padding_ms: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    silence_duration_ms: Option<u32>,
}

#[derive(Debug, Serialize)]
struct InputAudioBufferCommit {
    #[serde(rename = "type")]
    event_type: String,
}

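// Inbound (server -> client) events, dispatched on the JSON `type` field.
// `#[serde(other)]` swallows event types this adapter does not handle.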
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum OpenAIEvent {
    #[serde(rename = "session.created")]
    SessionCreated { session: SessionInfo },
    #[serde(rename = "session.updated")]
    SessionUpdated { session: SessionInfo },
    #[serde(rename = "input_audio_buffer.committed")]
    InputAudioBufferCommitted { item_id: String },
    #[serde(rename = "input_audio_buffer.cleared")]
    InputAudioBufferCleared,
    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
    ConversationItemInputAudioTranscriptionCompleted {
        item_id: String,
        content_index: u32,
        transcript: String,
    },
    #[serde(rename = "conversation.item.input_audio_transcription.delta")]
    ConversationItemInputAudioTranscriptionDelta {
        item_id: String,
        content_index: u32,
        delta: String,
    },
    #[serde(rename = "conversation.item.input_audio_transcription.failed")]
    ConversationItemInputAudioTranscriptionFailed {
        item_id: String,
        content_index: u32,
        error: OpenAIError,
    },
    #[serde(rename = "error")]
    Error { error: OpenAIError },
    #[serde(other)]
    Unknown,
}

#[derive(Debug, Deserialize)]
struct SessionInfo {
    id: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIError {
    #[serde(rename = "type")]
    error_type: String,
    message: String,
}

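// The realtime transcription events carry plain text only, so per-word
// entries are synthesized: the transcript is split on whitespace and every
// word gets a fixed confidence of 1.0.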
impl OpenAIAdapter {
    fn build_transcript_response(
        transcript: &str,
        is_final: bool,
        speech_final: bool,
    ) -> Vec<StreamResponse> {
        if transcript.is_empty() {
            return vec![];
        }

        let words: Vec<_> = transcript
            .split_whitespace()
            .map(|word| WordBuilder::new(word).confidence(1.0).build())
            .collect();

        let (start, duration) = calculate_time_span(&words);

        let channel = Channel {
            alternatives: vec![Alternatives {
                transcript: transcript.to_string(),
                words,
                confidence: 1.0,
                languages: vec![],
            }],
        };

        vec![StreamResponse::TranscriptResponse {
            is_final,
            speech_final,
            from_finalize: false,
            start,
            duration,
            channel,
            metadata: Metadata::default(),
            channel_index: vec![0, 1],
        }]
    }
}

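// Smoke tests against the live API; `#[ignore]`d by default and require
// OPENAI_API_KEY to be set.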
#[cfg(test)]
mod tests {
    use super::OpenAIAdapter;
    use crate::test_utils::{run_dual_test, run_single_test};
    use crate::ListenClient;

    #[tokio::test]
    #[ignore]
    async fn test_build_single() {
        let client = ListenClient::builder()
            .adapter::<OpenAIAdapter>()
            .api_base("wss://api.openai.com")
            .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
            .params(owhisper_interface::ListenParams {
                model: Some("gpt-4o-transcribe".to_string()),
                languages: vec![hypr_language::ISO639::En.into()],
                ..Default::default()
            })
            .build_single();

        run_single_test(client, "openai").await;
    }

    #[tokio::test]
    #[ignore]
    async fn test_build_dual() {
        let client = ListenClient::builder()
            .adapter::<OpenAIAdapter>()
            .api_base("wss://api.openai.com")
            .api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
            .params(owhisper_interface::ListenParams {
                model: Some("gpt-4o-transcribe".to_string()),
                languages: vec![hypr_language::ISO639::En.into()],
                ..Default::default()
            })
            .build_dual();

        run_dual_test(client, "openai").await;
    }
}