Skip to content

Commit f383715

Browse files
feat: add Gladia batch STT adapter (#2116)
* feat: add Gladia realtime STT adapter Co-Authored-By: yujonglee <[email protected]> * fix: implement Gladia two-step initialization flow - Add build_ws_url_with_api_key method to RealtimeSttAdapter trait - Use ureq for blocking POST request to get session token - Fix language_config format to use object with languages array - Return None for build_auth_header since token is in URL Co-Authored-By: yujonglee <[email protected]> * fix: correct ureq dependency ordering in Cargo.toml Co-Authored-By: yujonglee <[email protected]> * Add Gladia batch STT adapter Co-Authored-By: yujonglee <[email protected]> * Fix Gladia batch adapter to send proper audio file format and export GladiaAdapter Co-Authored-By: yujonglee <[email protected]> * Add integration test for Gladia batch STT adapter Co-Authored-By: yujonglee <[email protected]> * Format batch.rs Co-Authored-By: yujonglee <[email protected]> * Fix punctuated_word to use trimmed value consistently Co-Authored-By: yujonglee <[email protected]> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
1 parent 571e632 commit f383715

File tree

3 files changed

+392
-1
lines changed

3 files changed

+392
-1
lines changed
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
use std::path::{Path, PathBuf};
2+
use std::time::Duration;
3+
4+
use owhisper_interface::batch::{
5+
Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse,
6+
Results as BatchResults, Word as BatchWord,
7+
};
8+
use owhisper_interface::ListenParams;
9+
use serde::{Deserialize, Serialize};
10+
11+
use super::GladiaAdapter;
12+
use crate::adapter::{BatchFuture, BatchSttAdapter};
13+
use crate::error::Error;
14+
use crate::polling::{poll_until, PollingConfig, PollingResult};
15+
16+
impl BatchSttAdapter for GladiaAdapter {
17+
fn transcribe_file<'a, P: AsRef<Path> + Send + 'a>(
18+
&'a self,
19+
client: &'a reqwest::Client,
20+
api_base: &'a str,
21+
api_key: &'a str,
22+
params: &'a ListenParams,
23+
file_path: P,
24+
) -> BatchFuture<'a> {
25+
let path = file_path.as_ref().to_path_buf();
26+
Box::pin(Self::do_transcribe_file(
27+
client, api_base, api_key, params, path,
28+
))
29+
}
30+
}
31+
32+
#[derive(Debug, Serialize)]
33+
struct TranscriptRequest {
34+
audio_url: String,
35+
#[serde(skip_serializing_if = "Option::is_none")]
36+
language_config: Option<LanguageConfig>,
37+
#[serde(skip_serializing_if = "Option::is_none")]
38+
diarization: Option<bool>,
39+
}
40+
41+
#[derive(Debug, Serialize)]
42+
struct LanguageConfig {
43+
#[serde(skip_serializing_if = "Vec::is_empty")]
44+
languages: Vec<String>,
45+
#[serde(skip_serializing_if = "Option::is_none")]
46+
code_switching: Option<bool>,
47+
}
48+
49+
#[derive(Debug, Deserialize)]
50+
struct UploadResponse {
51+
audio_url: String,
52+
}
53+
54+
#[derive(Debug, Deserialize)]
55+
struct InitResponse {
56+
id: String,
57+
}
58+
59+
#[derive(Debug, Deserialize)]
60+
struct TranscriptResponse {
61+
status: String,
62+
#[serde(default)]
63+
error_code: Option<String>,
64+
#[serde(default)]
65+
file: Option<FileInfo>,
66+
#[serde(default)]
67+
result: Option<TranscriptResult>,
68+
}
69+
70+
#[derive(Debug, Deserialize)]
71+
struct FileInfo {
72+
#[serde(default)]
73+
audio_duration: Option<f64>,
74+
}
75+
76+
#[derive(Debug, Deserialize)]
77+
struct TranscriptResult {
78+
#[serde(default)]
79+
metadata: Option<ResultMetadata>,
80+
#[serde(default)]
81+
transcription: Option<Transcription>,
82+
}
83+
84+
#[derive(Debug, Deserialize)]
85+
struct ResultMetadata {
86+
#[serde(default)]
87+
audio_duration: Option<f64>,
88+
}
89+
90+
#[derive(Debug, Deserialize)]
91+
struct Transcription {
92+
#[serde(default)]
93+
full_transcript: Option<String>,
94+
#[serde(default)]
95+
utterances: Vec<Utterance>,
96+
}
97+
98+
#[derive(Debug, Deserialize)]
99+
struct Utterance {
100+
text: String,
101+
#[serde(default)]
102+
start: f64,
103+
#[serde(default)]
104+
end: f64,
105+
#[serde(default)]
106+
confidence: f64,
107+
#[serde(default)]
108+
channel: usize,
109+
#[serde(default)]
110+
speaker: Option<usize>,
111+
#[serde(default)]
112+
words: Vec<GladiaWord>,
113+
}
114+
115+
#[derive(Debug, Deserialize)]
116+
struct GladiaWord {
117+
word: String,
118+
#[serde(default)]
119+
start: f64,
120+
#[serde(default)]
121+
end: f64,
122+
#[serde(default)]
123+
confidence: f64,
124+
}
125+
126+
impl GladiaAdapter {
127+
async fn do_transcribe_file(
128+
client: &reqwest::Client,
129+
api_base: &str,
130+
api_key: &str,
131+
params: &ListenParams,
132+
file_path: PathBuf,
133+
) -> Result<BatchResponse, Error> {
134+
let base_url = Self::batch_api_url(api_base);
135+
136+
let file_bytes = tokio::fs::read(&file_path)
137+
.await
138+
.map_err(|e| Error::AudioProcessing(format!("failed to read file: {}", e)))?;
139+
140+
let file_name = file_path
141+
.file_name()
142+
.and_then(|n| n.to_str())
143+
.unwrap_or("audio.wav")
144+
.to_string();
145+
146+
let mime_type = match file_path.extension().and_then(|e| e.to_str()) {
147+
Some("wav") => "audio/wav",
148+
Some("mp3") => "audio/mpeg",
149+
Some("ogg") => "audio/ogg",
150+
Some("flac") => "audio/flac",
151+
Some("m4a") => "audio/mp4",
152+
Some("webm") => "audio/webm",
153+
_ => "application/octet-stream",
154+
};
155+
156+
let upload_url = format!("{}/upload", base_url);
157+
let form = reqwest::multipart::Form::new().part(
158+
"audio",
159+
reqwest::multipart::Part::bytes(file_bytes)
160+
.file_name(file_name)
161+
.mime_str(mime_type)
162+
.map_err(|e| Error::AudioProcessing(e.to_string()))?,
163+
);
164+
165+
let upload_response = client
166+
.post(&upload_url)
167+
.header("x-gladia-key", api_key)
168+
.multipart(form)
169+
.send()
170+
.await?;
171+
172+
let upload_status = upload_response.status();
173+
if !upload_status.is_success() {
174+
return Err(Error::UnexpectedStatus {
175+
status: upload_status,
176+
body: upload_response.text().await.unwrap_or_default(),
177+
});
178+
}
179+
180+
let upload_result: UploadResponse = upload_response.json().await?;
181+
182+
let languages: Vec<String> = params
183+
.languages
184+
.iter()
185+
.map(|l| l.iso639().code().to_string())
186+
.collect();
187+
188+
let language_config = if languages.is_empty() {
189+
None
190+
} else {
191+
Some(LanguageConfig {
192+
languages,
193+
code_switching: if params.languages.len() > 1 {
194+
Some(true)
195+
} else {
196+
None
197+
},
198+
})
199+
};
200+
201+
let transcript_request = TranscriptRequest {
202+
audio_url: upload_result.audio_url,
203+
language_config,
204+
diarization: Some(true),
205+
};
206+
207+
let transcript_url = format!("{}/pre-recorded", base_url);
208+
let create_response = client
209+
.post(&transcript_url)
210+
.header("x-gladia-key", api_key)
211+
.header("Content-Type", "application/json")
212+
.json(&transcript_request)
213+
.send()
214+
.await?;
215+
216+
let create_status = create_response.status();
217+
if !create_status.is_success() {
218+
return Err(Error::UnexpectedStatus {
219+
status: create_status,
220+
body: create_response.text().await.unwrap_or_default(),
221+
});
222+
}
223+
224+
let create_result: InitResponse = create_response.json().await?;
225+
let transcript_id = create_result.id;
226+
227+
let poll_url = format!("{}/pre-recorded/{}", base_url, transcript_id);
228+
229+
let config = PollingConfig::default()
230+
.with_interval(Duration::from_secs(3))
231+
.with_timeout_error("transcription timed out".to_string());
232+
233+
poll_until(
234+
|| async {
235+
let poll_response = client
236+
.get(&poll_url)
237+
.header("x-gladia-key", api_key)
238+
.send()
239+
.await?;
240+
241+
let poll_status = poll_response.status();
242+
if !poll_status.is_success() {
243+
return Err(Error::UnexpectedStatus {
244+
status: poll_status,
245+
body: poll_response.text().await.unwrap_or_default(),
246+
});
247+
}
248+
249+
let result: TranscriptResponse = poll_response.json().await?;
250+
251+
match result.status.as_str() {
252+
"done" => Ok(PollingResult::Complete(Self::convert_to_batch_response(
253+
result,
254+
))),
255+
"error" => {
256+
let error_msg = result
257+
.error_code
258+
.unwrap_or_else(|| "unknown error".to_string());
259+
Ok(PollingResult::Failed(format!(
260+
"transcription failed: {}",
261+
error_msg
262+
)))
263+
}
264+
_ => Ok(PollingResult::Continue),
265+
}
266+
},
267+
config,
268+
)
269+
.await
270+
}
271+
272+
fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
273+
let result = response.result.unwrap_or(TranscriptResult {
274+
metadata: None,
275+
transcription: None,
276+
});
277+
278+
let transcription = result.transcription.unwrap_or(Transcription {
279+
full_transcript: None,
280+
utterances: Vec::new(),
281+
});
282+
283+
let words: Vec<BatchWord> = transcription
284+
.utterances
285+
.iter()
286+
.flat_map(|u| {
287+
u.words.iter().map(|w| {
288+
let trimmed = w.word.trim().to_string();
289+
BatchWord {
290+
word: trimmed.clone(),
291+
start: w.start,
292+
end: w.end,
293+
confidence: w.confidence,
294+
speaker: u.speaker,
295+
punctuated_word: Some(trimmed),
296+
}
297+
})
298+
})
299+
.collect();
300+
301+
let transcript = transcription.full_transcript.unwrap_or_default();
302+
303+
let avg_confidence = if words.is_empty() {
304+
1.0
305+
} else {
306+
words.iter().map(|w| w.confidence).sum::<f64>() / words.len() as f64
307+
};
308+
309+
let channel = BatchChannel {
310+
alternatives: vec![BatchAlternatives {
311+
transcript,
312+
confidence: avg_confidence,
313+
words,
314+
}],
315+
};
316+
317+
let audio_duration = result
318+
.metadata
319+
.and_then(|m| m.audio_duration)
320+
.or_else(|| response.file.and_then(|f| f.audio_duration));
321+
322+
BatchResponse {
323+
metadata: serde_json::json!({
324+
"audio_duration": audio_duration,
325+
}),
326+
results: BatchResults {
327+
channels: vec![channel],
328+
},
329+
}
330+
}
331+
}
332+
333+
#[cfg(test)]
334+
mod tests {
335+
use super::*;
336+
337+
#[tokio::test]
338+
#[ignore]
339+
async fn test_gladia_batch_transcription() {
340+
let api_key = std::env::var("GLADIA_API_KEY").expect("GLADIA_API_KEY not set");
341+
let client = reqwest::Client::new();
342+
let adapter = GladiaAdapter::default();
343+
let params = ListenParams::default();
344+
345+
let audio_path = std::path::PathBuf::from(hypr_data::english_1::AUDIO_PATH);
346+
347+
let result = adapter
348+
.transcribe_file(&client, "", &api_key, &params, &audio_path)
349+
.await
350+
.expect("transcription failed");
351+
352+
assert!(!result.results.channels.is_empty());
353+
assert!(!result.results.channels[0].alternatives.is_empty());
354+
assert!(!result.results.channels[0].alternatives[0]
355+
.transcript
356+
.is_empty());
357+
assert!(!result.results.channels[0].alternatives[0].words.is_empty());
358+
}
359+
}

0 commit comments

Comments
 (0)