diff --git a/owhisper/owhisper-client/src/adapter/gladia/batch.rs b/owhisper/owhisper-client/src/adapter/gladia/batch.rs
new file mode 100644
index 0000000000..c13e130b4b
--- /dev/null
+++ b/owhisper/owhisper-client/src/adapter/gladia/batch.rs
@@ -0,0 +1,404 @@
+use std::path::{Path, PathBuf};
+use std::time::Duration;
+
+use owhisper_interface::batch::{
+    Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse,
+    Results as BatchResults, Word as BatchWord,
+};
+use owhisper_interface::ListenParams;
+use serde::{Deserialize, Serialize};
+
+use super::GladiaAdapter;
+use crate::adapter::{BatchFuture, BatchSttAdapter};
+use crate::error::Error;
+use crate::polling::{poll_until, PollingConfig, PollingResult};
+
+impl BatchSttAdapter for GladiaAdapter {
+    /// Transcribes the audio file at `file_path` through Gladia's batch API.
+    fn transcribe_file<'a, P: AsRef<Path> + Send + 'a>(
+        &'a self,
+        client: &'a reqwest::Client,
+        api_base: &'a str,
+        api_key: &'a str,
+        params: &'a ListenParams,
+        file_path: P,
+    ) -> BatchFuture<'a> {
+        // Take an owned copy so the boxed future does not hold on to `file_path`.
+        let path = file_path.as_ref().to_path_buf();
+        Box::pin(Self::do_transcribe_file(
+            client, api_base, api_key, params, path,
+        ))
+    }
+}
+
+/// JSON body sent to `{base}/pre-recorded` to start a transcription job.
+#[derive(Debug, Serialize)]
+struct TranscriptRequest {
+    audio_url: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    language_config: Option<LanguageConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    diarization: Option<bool>,
+}
+
+#[derive(Debug, Serialize)]
+struct LanguageConfig {
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    languages: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    code_switching: Option<bool>,
+}
+
+/// Body returned by `{base}/upload`.
+#[derive(Debug, Deserialize)]
+struct UploadResponse {
+    audio_url: String,
+}
+
+/// Body returned when a job is created at `{base}/pre-recorded`.
+#[derive(Debug, Deserialize)]
+struct InitResponse {
+    id: String,
+}
+
+/// Body returned by `{base}/pre-recorded/{id}` while polling a job.
+#[derive(Debug, Deserialize)]
+struct TranscriptResponse {
+    status: String,
+    // Kept as raw JSON: Gladia documents `error_code` as an integer HTTP status,
+    // so a `String` field would turn every failed job into a decode error.
+    #[serde(default)]
+    error_code: Option<serde_json::Value>,
+    #[serde(default)]
+    file: Option<FileInfo>,
+    #[serde(default)]
+    result: Option<TranscriptResult>,
+}
+
+#[derive(Debug, Deserialize)]
+struct FileInfo {
+    #[serde(default)]
+    audio_duration: Option<f64>,
+}
+
+#[derive(Debug, Deserialize)]
+struct TranscriptResult {
+    #[serde(default)]
+    metadata: Option<ResultMetadata>,
+    #[serde(default)]
+    transcription: Option<Transcription>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ResultMetadata {
+    #[serde(default)]
+    audio_duration: Option<f64>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Transcription {
+    #[serde(default)]
+    full_transcript: Option<String>,
+    #[serde(default)]
+    utterances: Vec<Utterance>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Utterance {
+    text: String,
+    #[serde(default)]
+    start: f64,
+    #[serde(default)]
+    end: f64,
+    #[serde(default)]
+    confidence: f64,
+    #[serde(default)]
+    channel: usize,
+    #[serde(default)]
+    speaker: Option<usize>,
+    #[serde(default)]
+    words: Vec<GladiaWord>,
+}
+
+#[derive(Debug, Deserialize)]
+struct GladiaWord {
+    word: String,
+    #[serde(default)]
+    start: f64,
+    #[serde(default)]
+    end: f64,
+    #[serde(default)]
+    confidence: f64,
+}
+
+impl GladiaAdapter {
+    /// Uploads the file, creates a transcription job, then polls it every three
+    /// seconds until the job reports `done` or `error`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::AudioProcessing` if the file cannot be read or the upload
+    /// part cannot be built, and `Error::UnexpectedStatus` if any request gets a
+    /// non-success HTTP status. Other request errors are propagated with `?`.
+    async fn do_transcribe_file(
+        client: &reqwest::Client,
+        api_base: &str,
+        api_key: &str,
+        params: &ListenParams,
+        file_path: PathBuf,
+    ) -> Result<BatchResponse, Error> {
+        let base_url = Self::batch_api_url(api_base);
+        // Drop trailing slashes so a base like `https://host/v2/` cannot yield
+        // `https://host/v2//upload` when the endpoint paths are appended.
+        let base_url = base_url.as_str().trim_end_matches('/');
+
+        let file_bytes = tokio::fs::read(&file_path)
+            .await
+            .map_err(|e| Error::AudioProcessing(format!("failed to read file: {}", e)))?;
+
+        let file_name = file_path
+            .file_name()
+            .and_then(|n| n.to_str())
+            .unwrap_or("audio.wav")
+            .to_string();
+
+        // Lowercase first so `.WAV` or `.Mp3` still map to an audio MIME type.
+        let extension = file_path
+            .extension()
+            .and_then(|e| e.to_str())
+            .map(str::to_ascii_lowercase);
+        let mime_type = match extension.as_deref() {
+            Some("wav") => "audio/wav",
+            Some("mp3") => "audio/mpeg",
+            Some("ogg") => "audio/ogg",
+            Some("flac") => "audio/flac",
+            Some("m4a") => "audio/mp4",
+            Some("webm") => "audio/webm",
+            _ => "application/octet-stream",
+        };
+
+        let upload_url = format!("{}/upload", base_url);
+        let form = reqwest::multipart::Form::new().part(
+            "audio",
+            reqwest::multipart::Part::bytes(file_bytes)
+                .file_name(file_name)
+                .mime_str(mime_type)
+                .map_err(|e| Error::AudioProcessing(e.to_string()))?,
+        );
+
+        let upload_response = client
+            .post(&upload_url)
+            .header("x-gladia-key", api_key)
+            .multipart(form)
+            .send()
+            .await?;
+
+        let upload_status = upload_response.status();
+        if !upload_status.is_success() {
+            return Err(Error::UnexpectedStatus {
+                status: upload_status,
+                body: upload_response.text().await.unwrap_or_default(),
+            });
+        }
+
+        let upload_result: UploadResponse = upload_response.json().await?;
+
+        let languages: Vec<String> = params
+            .languages
+            .iter()
+            .map(|l| l.iso639().code().to_string())
+            .collect();
+
+        let language_config = if languages.is_empty() {
+            None
+        } else {
+            Some(LanguageConfig {
+                languages,
+                code_switching: if params.languages.len() > 1 {
+                    Some(true)
+                } else {
+                    None
+                },
+            })
+        };
+
+        let transcript_request = TranscriptRequest {
+            audio_url: upload_result.audio_url,
+            language_config,
+            diarization: Some(true),
+        };
+
+        let transcript_url = format!("{}/pre-recorded", base_url);
+        let create_response = client
+            .post(&transcript_url)
+            .header("x-gladia-key", api_key)
+            .header("Content-Type", "application/json")
+            .json(&transcript_request)
+            .send()
+            .await?;
+
+        let create_status = create_response.status();
+        if !create_status.is_success() {
+            return Err(Error::UnexpectedStatus {
+                status: create_status,
+                body: create_response.text().await.unwrap_or_default(),
+            });
+        }
+
+        let create_result: InitResponse = create_response.json().await?;
+        let transcript_id = create_result.id;
+
+        let poll_url = format!("{}/pre-recorded/{}", base_url, transcript_id);
+
+        let config = PollingConfig::default()
+            .with_interval(Duration::from_secs(3))
+            .with_timeout_error("transcription timed out".to_string());
+
+        poll_until(
+            || async {
+                let poll_response = client
+                    .get(&poll_url)
+                    .header("x-gladia-key", api_key)
+                    .send()
+                    .await?;
+
+                let poll_status = poll_response.status();
+                if !poll_status.is_success() {
+                    return Err(Error::UnexpectedStatus {
+                        status: poll_status,
+                        body: poll_response.text().await.unwrap_or_default(),
+                    });
+                }
+
+                let result: TranscriptResponse = poll_response.json().await?;
+
+                match result.status.as_str() {
+                    "done" => Ok(PollingResult::Complete(Self::convert_to_batch_response(
+                        result,
+                    ))),
+                    "error" => {
+                        // Print string codes bare; `Value`'s `Display` would quote them.
+                        let error_msg = match result.error_code {
+                            Some(serde_json::Value::String(code)) => code,
+                            Some(code) => code.to_string(),
+                            None => "unknown error".to_string(),
+                        };
+                        Ok(PollingResult::Failed(format!(
+                            "transcription failed: {}",
+                            error_msg
+                        )))
+                    }
+                    // Any other status is treated as "still in progress".
+                    _ => Ok(PollingResult::Continue),
+                }
+            },
+            config,
+        )
+        .await
+    }
+
+    /// Flattens a finished Gladia response into a single-channel `BatchResponse`.
+    ///
+    /// Words from every utterance are concatenated in order and tagged with their
+    /// utterance's speaker; a missing `result` or `transcription` yields an empty
+    /// transcript rather than an error.
+    fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
+        let result = response.result.unwrap_or(TranscriptResult {
+            metadata: None,
+            transcription: None,
+        });
+
+        let transcription = result.transcription.unwrap_or(Transcription {
+            full_transcript: None,
+            utterances: Vec::new(),
+        });
+
+        let words: Vec<BatchWord> = transcription
+            .utterances
+            .iter()
+            .flat_map(|u| {
+                u.words.iter().map(|w| {
+                    let trimmed = w.word.trim().to_string();
+                    BatchWord {
+                        word: trimmed.clone(),
+                        start: w.start,
+                        end: w.end,
+                        confidence: w.confidence,
+                        speaker: u.speaker,
+                        punctuated_word: Some(trimmed),
+                    }
+                })
+            })
+            .collect();
+
+        let transcript = transcription.full_transcript.unwrap_or_default();
+
+        let avg_confidence = if words.is_empty() {
+            1.0
+        } else {
+            words.iter().map(|w| w.confidence).sum::<f64>() / words.len() as f64
+        };
+
+        let channel = BatchChannel {
+            alternatives: vec![BatchAlternatives {
+                transcript,
+                confidence: avg_confidence,
+                words,
+            }],
+        };
+
+        // Prefer the duration from the result metadata, then the file info.
+        let audio_duration = result
+            .metadata
+            .and_then(|m| m.audio_duration)
+            .or_else(|| response.file.and_then(|f| f.audio_duration));
+
+        BatchResponse {
+            metadata: serde_json::json!({
+                "audio_duration": audio_duration,
+            }),
+            results: BatchResults {
+                channels: vec![channel],
+            },
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_error_code_accepts_number_or_string() {
+        let numeric: TranscriptResponse =
+            serde_json::from_str(r#"{"status":"error","error_code":500}"#).unwrap();
+        assert_eq!(numeric.error_code, Some(serde_json::json!(500)));
+
+        let textual: TranscriptResponse =
+            serde_json::from_str(r#"{"status":"error","error_code":"oops"}"#).unwrap();
+        assert_eq!(textual.error_code, Some(serde_json::json!("oops")));
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_gladia_batch_transcription() {
+        let api_key = std::env::var("GLADIA_API_KEY").expect("GLADIA_API_KEY not set");
+        let client = reqwest::Client::new();
+        let adapter = GladiaAdapter::default();
+        let params = ListenParams::default();
+
+        let audio_path = std::path::PathBuf::from(hypr_data::english_1::AUDIO_PATH);
+
+        let result = adapter
+            .transcribe_file(&client, "", &api_key, &params, &audio_path)
+            .await
+            .expect("transcription failed");
+
+        assert!(!result.results.channels.is_empty());
+        assert!(!result.results.channels[0].alternatives.is_empty());
+        assert!(!result.results.channels[0].alternatives[0]
+            .transcript
+            .is_empty());
+        assert!(!result.results.channels[0].alternatives[0].words.is_empty());
+    }
+}
diff --git a/owhisper/owhisper-client/src/adapter/gladia/mod.rs b/owhisper/owhisper-client/src/adapter/gladia/mod.rs
index 8fdd4f7f5c..b5b2307e78 100644
--- a/owhisper/owhisper-client/src/adapter/gladia/mod.rs
+++ b/owhisper/owhisper-client/src/adapter/gladia/mod.rs
@@ -1,7 +1,9 @@
+mod batch;
 mod live;
 
 pub(crate) const DEFAULT_API_HOST: &str = "api.gladia.io";
 pub(crate) const WS_PATH: &str = "/v2/live";
+const API_BASE: &str = "https://api.gladia.io/v2";
 
 #[derive(Clone, Default)]
 pub struct GladiaAdapter;
@@ -52,6 +54,21 @@ impl GladiaAdapter {
             .expect("invalid_ws_url");
         (url, existing_params)
     }
+
+    /// Returns the base URL for Gladia's batch REST endpoints.
+    ///
+    /// An empty `api_base` selects the hosted default, `API_BASE`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if a non-empty `api_base` is not a valid absolute URL.
+    pub(crate) fn batch_api_url(api_base: &str) -> url::Url {
+        if api_base.is_empty() {
+            return API_BASE.parse().expect("invalid_default_api_url");
+        }
+
+        api_base.parse().expect("invalid_api_base")
+    }
 }
 
 #[cfg(test)]
@@ -98,4 +115,24 @@ mod tests {
             );
         }
     }
+
+    #[test]
+    fn test_is_host() {
+        assert!(GladiaAdapter::is_host("https://api.gladia.io"));
+        assert!(GladiaAdapter::is_host("https://api.gladia.io/v2"));
+        assert!(!GladiaAdapter::is_host("https://api.deepgram.com"));
+        assert!(!GladiaAdapter::is_host("https://api.assemblyai.com"));
+    }
+
+    #[test]
+    fn test_batch_api_url_empty_uses_default() {
+        let url = GladiaAdapter::batch_api_url("");
+        assert_eq!(url.as_str(), "https://api.gladia.io/v2");
+    }
+
+    #[test]
+    fn test_batch_api_url_custom() {
+        let url = GladiaAdapter::batch_api_url("https://custom.gladia.io/v2");
+        assert_eq!(url.as_str(), "https://custom.gladia.io/v2");
+    }
 }
diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs
index 18ecd158d3..4dd2a33cd3 100644
--- a/owhisper/owhisper-client/src/lib.rs
+++ b/owhisper/owhisper-client/src/lib.rs
@@ -11,7 +11,8 @@ use std::marker::PhantomData;
 
 pub use adapter::{
     append_provider_param, is_local_host, AdapterKind, ArgmaxAdapter, AssemblyAIAdapter,
-    BatchSttAdapter, DeepgramAdapter, FireworksAdapter, RealtimeSttAdapter, SonioxAdapter,
+    BatchSttAdapter, DeepgramAdapter, FireworksAdapter, GladiaAdapter, RealtimeSttAdapter,
+    SonioxAdapter,
 };
 pub use batch::BatchClient;
 pub use error::Error;