1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions owhisper/owhisper-client/Cargo.toml
@@ -17,6 +17,7 @@ tokio = { workspace = true }
tokio-stream = { workspace = true }
ureq = { version = "2", features = ["json"] }

base64 = "0.22.1"
bytes = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
4 changes: 4 additions & 0 deletions owhisper/owhisper-client/src/adapter/mod.rs
@@ -56,6 +56,10 @@ pub trait RealtimeSttAdapter: Clone + Default + Send + Sync + 'static {

fn finalize_message(&self) -> Message;

fn audio_to_message(&self, audio: bytes::Bytes) -> Message {
Message::Binary(audio)
}

fn initial_message(
&self,
_api_key: Option<&str>,
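For context, here is a self-contained sketch of how this new trait hook is intended to be used. The `Message` enum and both adapters below are simplified stand-ins rather than the crate's real types: adapters that stream raw audio keep the binary default, while adapters that need a text envelope (such as the OpenAI adapter later in this diff) override it.

```rust
// Simplified, illustrative sketch only — not the client's actual types.
#[derive(Debug)]
enum Message {
    Binary(Vec<u8>),
    Text(String),
}

trait RealtimeSttAdapter {
    // Default: forward each audio chunk as a binary WebSocket frame.
    fn audio_to_message(&self, audio: Vec<u8>) -> Message {
        Message::Binary(audio)
    }
}

struct PassthroughAdapter;
impl RealtimeSttAdapter for PassthroughAdapter {} // keeps the binary default

struct TextEnvelopeAdapter;
impl RealtimeSttAdapter for TextEnvelopeAdapter {
    // Override: wrap the chunk in a text payload (the format here is made up).
    fn audio_to_message(&self, audio: Vec<u8>) -> Message {
        Message::Text(format!("{{\"bytes\":{}}}", audio.len()))
    }
}

fn main() {
    let chunk = vec![0u8; 4];
    println!("{:?}", PassthroughAdapter.audio_to_message(chunk.clone()));
    println!("{:?}", TextEnvelopeAdapter.audio_to_message(chunk));
}
```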
52 changes: 44 additions & 8 deletions owhisper/owhisper-client/src/adapter/openai/live.rs
@@ -16,9 +16,8 @@ impl RealtimeSttAdapter for OpenAIAdapter {
false
}

fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url {
let (mut url, existing_params) =
Self::build_ws_url_from_base(api_base, params.model.as_deref());
fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url {
let (mut url, existing_params) = Self::build_ws_url_from_base(api_base);

if !existing_params.is_empty() {
let mut query_pairs = url.query_pairs_mut();
@@ -38,6 +37,16 @@ impl RealtimeSttAdapter for OpenAIAdapter {
None
}

fn audio_to_message(&self, audio: bytes::Bytes) -> Message {
use base64::Engine;
let base64_audio = base64::engine::general_purpose::STANDARD.encode(&audio);
let event = InputAudioBufferAppend {
event_type: "input_audio_buffer.append".to_string(),
audio: base64_audio,
};
Message::Text(serde_json::to_string(&event).unwrap().into())
}

fn initial_message(
&self,
_api_key: Option<&str>,
@@ -49,7 +58,10 @@ impl RealtimeSttAdapter for OpenAIAdapter {
.first()
.map(|l| l.iso639().code().to_string());

let model = params.model.as_deref().unwrap_or(super::DEFAULT_MODEL);
let model = params
.model
.as_deref()
.unwrap_or(super::DEFAULT_TRANSCRIPTION_MODEL);

let session_config = SessionUpdateEvent {
event_type: "session.update".to_string(),
@@ -59,7 +71,7 @@ impl RealtimeSttAdapter for OpenAIAdapter {
input: Some(AudioInputConfig {
format: Some(AudioFormat {
format_type: "audio/pcm".to_string(),
rate: 24000,
rate: params.sample_rate,
}),
transcription: Some(TranscriptionConfig {
model: model.to_string(),
@@ -78,6 +90,7 @@ impl RealtimeSttAdapter for OpenAIAdapter {
};

let json = serde_json::to_string(&session_config).ok()?;
tracing::debug!(payload = %json, "openai_session_update_payload");
Some(Message::Text(json.into()))
}

@@ -114,6 +127,14 @@ impl RealtimeSttAdapter for OpenAIAdapter {
tracing::debug!("openai_audio_buffer_cleared");
vec![]
}
OpenAIEvent::InputAudioBufferSpeechStarted { item_id } => {
tracing::debug!(item_id = %item_id, "openai_speech_started");
vec![]
}
OpenAIEvent::InputAudioBufferSpeechStopped { item_id } => {
tracing::debug!(item_id = %item_id, "openai_speech_stopped");
vec![]
}
OpenAIEvent::ConversationItemInputAudioTranscriptionCompleted {
item_id,
content_index,
@@ -226,6 +247,13 @@ struct TurnDetection {
silence_duration_ms: Option<u32>,
}

#[derive(Debug, Serialize)]
struct InputAudioBufferAppend {
#[serde(rename = "type")]
event_type: String,
audio: String,
}

#[derive(Debug, Serialize)]
struct InputAudioBufferCommit {
#[serde(rename = "type")]
@@ -243,6 +271,10 @@ enum OpenAIEvent {
InputAudioBufferCommitted { item_id: String },
#[serde(rename = "input_audio_buffer.cleared")]
InputAudioBufferCleared,
#[serde(rename = "input_audio_buffer.speech_started")]
InputAudioBufferSpeechStarted { item_id: String },
#[serde(rename = "input_audio_buffer.speech_stopped")]
InputAudioBufferSpeechStopped { item_id: String },
#[serde(rename = "conversation.item.input_audio_transcription.completed")]
ConversationItemInputAudioTranscriptionCompleted {
item_id: String,
@@ -321,9 +353,11 @@ impl OpenAIAdapter {
#[cfg(test)]
mod tests {
use super::OpenAIAdapter;
use crate::test_utils::{run_dual_test, run_single_test};
use crate::test_utils::{run_dual_test_with_rate, run_single_test_with_rate};
use crate::ListenClient;

const OPENAI_SAMPLE_RATE: u32 = 24000;

#[tokio::test]
#[ignore]
async fn test_build_single() {
@@ -334,11 +368,12 @@ mod tests {
.params(owhisper_interface::ListenParams {
model: Some("gpt-4o-transcribe".to_string()),
languages: vec![hypr_language::ISO639::En.into()],
sample_rate: OPENAI_SAMPLE_RATE,
..Default::default()
})
.build_single();

run_single_test(client, "openai").await;
run_single_test_with_rate(client, "openai", OPENAI_SAMPLE_RATE).await;
}

#[tokio::test]
@@ -351,10 +386,11 @@ mod tests {
.params(owhisper_interface::ListenParams {
model: Some("gpt-4o-transcribe".to_string()),
languages: vec![hypr_language::ISO639::En.into()],
sample_rate: OPENAI_SAMPLE_RATE,
..Default::default()
})
.build_dual();

run_dual_test(client, "openai").await;
run_dual_test_with_rate(client, "openai", OPENAI_SAMPLE_RATE).await;
}
}
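To make the new `audio_to_message` override concrete, here is a minimal standalone sketch of the payload it produces. The struct mirrors the `InputAudioBufferAppend` type added in this diff, the dependencies match the crate's (`base64` 0.22, `serde`, `serde_json`), and the sample bytes are made up for illustration:

```rust
use base64::Engine;
use serde::Serialize;

#[derive(Serialize)]
struct InputAudioBufferAppend {
    #[serde(rename = "type")]
    event_type: String,
    audio: String,
}

fn main() {
    // A few hypothetical 16-bit PCM bytes standing in for a real audio chunk.
    let pcm_chunk: &[u8] = &[0, 0, 1, 0, 255, 127];
    let event = InputAudioBufferAppend {
        event_type: "input_audio_buffer.append".to_string(),
        audio: base64::engine::general_purpose::STANDARD.encode(pcm_chunk),
    };
    // Prints: {"type":"input_audio_buffer.append","audio":"AAABAP9/"}
    println!("{}", serde_json::to_string(&event).unwrap());
}
```

The tests above also pin the sample rate to 24000 via `OPENAI_SAMPLE_RATE`, matching the session config, which now takes `rate` from `ListenParams` instead of hard-coding it.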
37 changes: 9 additions & 28 deletions owhisper/owhisper-client/src/adapter/openai/mod.rs
@@ -3,7 +3,7 @@ mod live;

pub(crate) const DEFAULT_WS_HOST: &str = "api.openai.com";
pub(crate) const WS_PATH: &str = "/v1/realtime";
pub(crate) const DEFAULT_MODEL: &str = "gpt-4o-transcribe";
pub(crate) const DEFAULT_TRANSCRIPTION_MODEL: &str = "gpt-4o-transcribe";

#[derive(Clone, Default)]
pub struct OpenAIAdapter;
@@ -21,17 +21,13 @@ impl OpenAIAdapter {
host.contains("openai.com")
}

pub(crate) fn build_ws_url_from_base(
api_base: &str,
model: Option<&str>,
) -> (url::Url, Vec<(String, String)>) {
pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) {
if api_base.is_empty() {
let model = model.unwrap_or(DEFAULT_MODEL);
return (
format!("wss://{}{}", DEFAULT_WS_HOST, WS_PATH)
.parse()
.expect("invalid_default_ws_url"),
vec![("model".to_string(), model.to_string())],
vec![("intent".to_string(), "transcription".to_string())],
);
}

@@ -42,9 +38,8 @@ impl OpenAIAdapter {
let parsed: url::Url = api_base.parse().expect("invalid_api_base");
let mut existing_params = super::extract_query_params(&parsed);

if !existing_params.iter().any(|(k, _)| k == "model") {
let model = model.unwrap_or(DEFAULT_MODEL);
existing_params.push(("model".to_string(), model.to_string()));
if !existing_params.iter().any(|(k, _)| k == "intent") {
existing_params.push(("intent".to_string(), "transcription".to_string()));
}

let host = parsed.host_str().unwrap_or(DEFAULT_WS_HOST);
@@ -64,40 +59,26 @@ mod tests {

#[test]
fn test_build_ws_url_from_base_empty() {
let (url, params) = OpenAIAdapter::build_ws_url_from_base("", None);
let (url, params) = OpenAIAdapter::build_ws_url_from_base("");
assert_eq!(url.as_str(), "wss://api.openai.com/v1/realtime");
assert_eq!(
params,
vec![("model".to_string(), "gpt-4o-transcribe".to_string())]
);
}

#[test]
fn test_build_ws_url_from_base_with_model() {
let (url, params) =
OpenAIAdapter::build_ws_url_from_base("", Some("gpt-4o-mini-realtime-preview"));
assert_eq!(url.as_str(), "wss://api.openai.com/v1/realtime");
assert_eq!(
params,
vec![(
"model".to_string(),
"gpt-4o-mini-realtime-preview".to_string()
)]
vec![("intent".to_string(), "transcription".to_string())]
);
}

#[test]
fn test_build_ws_url_from_base_proxy() {
let (url, params) =
OpenAIAdapter::build_ws_url_from_base("https://api.hyprnote.com?provider=openai", None);
OpenAIAdapter::build_ws_url_from_base("https://api.hyprnote.com?provider=openai");
assert_eq!(url.as_str(), "wss://api.hyprnote.com/listen");
assert_eq!(params, vec![("provider".to_string(), "openai".to_string())]);
}

#[test]
fn test_build_ws_url_from_base_localhost() {
let (url, params) =
OpenAIAdapter::build_ws_url_from_base("http://localhost:8787?provider=openai", None);
OpenAIAdapter::build_ws_url_from_base("http://localhost:8787?provider=openai");
assert_eq!(url.as_str(), "ws://localhost:8787/listen");
assert_eq!(params, vec![("provider".to_string(), "openai".to_string())]);
}
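As a quick reference, the builder's post-change behavior can be restated from the assertions above as plain data. The `CASES` constant below is only illustrative and not part of the crate; the tuples mirror the `(url, params)` return value of `build_ws_url_from_base`:

```rust
// (api_base, returned base URL, returned query params) — taken from the unit tests.
const CASES: &[(&str, &str, &[(&str, &str)])] = &[
    // empty base -> default OpenAI realtime endpoint plus a transcription intent
    ("", "wss://api.openai.com/v1/realtime", &[("intent", "transcription")]),
    // proxy base -> generic /listen path, existing query params preserved
    (
        "https://api.hyprnote.com?provider=openai",
        "wss://api.hyprnote.com/listen",
        &[("provider", "openai")],
    ),
    // localhost base -> plain ws scheme, same /listen path
    (
        "http://localhost:8787?provider=openai",
        "ws://localhost:8787/listen",
        &[("provider", "openai")],
    ),
];

fn main() {
    for (base, url, params) in CASES {
        println!("{base:?} -> {url} {params:?}");
    }
}
```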