Skip to content

Commit 1dcd7f9

Browse files
Add Gladia batch STT adapter
Co-Authored-By: yujonglee <yujonglee.dev@gmail.com>
1 parent a0fe0a6 commit 1dcd7f9

File tree

2 files changed

+381
-0
lines changed

2 files changed

+381
-0
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
use std::path::{Path, PathBuf};
2+
use std::time::Duration;
3+
4+
use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source};
5+
use owhisper_interface::batch::{
6+
Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse,
7+
Results as BatchResults, Word as BatchWord,
8+
};
9+
use owhisper_interface::ListenParams;
10+
use serde::{Deserialize, Serialize};
11+
12+
use super::GladiaAdapter;
13+
use crate::adapter::{BatchFuture, BatchSttAdapter};
14+
use crate::error::Error;
15+
use crate::polling::{poll_until, PollingConfig, PollingResult};
16+
17+
impl BatchSttAdapter for GladiaAdapter {
18+
fn transcribe_file<'a, P: AsRef<Path> + Send + 'a>(
19+
&'a self,
20+
client: &'a reqwest::Client,
21+
api_base: &'a str,
22+
api_key: &'a str,
23+
params: &'a ListenParams,
24+
file_path: P,
25+
) -> BatchFuture<'a> {
26+
let path = file_path.as_ref().to_path_buf();
27+
Box::pin(Self::do_transcribe_file(
28+
client, api_base, api_key, params, path,
29+
))
30+
}
31+
}
32+
33+
#[derive(Debug, Serialize)]
34+
struct TranscriptRequest {
35+
audio_url: String,
36+
#[serde(skip_serializing_if = "Option::is_none")]
37+
language_config: Option<LanguageConfig>,
38+
#[serde(skip_serializing_if = "Option::is_none")]
39+
diarization: Option<bool>,
40+
}
41+
42+
#[derive(Debug, Serialize)]
43+
struct LanguageConfig {
44+
#[serde(skip_serializing_if = "Vec::is_empty")]
45+
languages: Vec<String>,
46+
#[serde(skip_serializing_if = "Option::is_none")]
47+
code_switching: Option<bool>,
48+
}
49+
50+
#[derive(Debug, Deserialize)]
51+
struct UploadResponse {
52+
audio_url: String,
53+
}
54+
55+
#[derive(Debug, Deserialize)]
56+
struct InitResponse {
57+
id: String,
58+
}
59+
60+
#[derive(Debug, Deserialize)]
61+
struct TranscriptResponse {
62+
status: String,
63+
#[serde(default)]
64+
error_code: Option<String>,
65+
#[serde(default)]
66+
file: Option<FileInfo>,
67+
#[serde(default)]
68+
result: Option<TranscriptResult>,
69+
}
70+
71+
#[derive(Debug, Deserialize)]
72+
struct FileInfo {
73+
#[serde(default)]
74+
audio_duration: Option<f64>,
75+
}
76+
77+
#[derive(Debug, Deserialize)]
78+
struct TranscriptResult {
79+
#[serde(default)]
80+
metadata: Option<ResultMetadata>,
81+
#[serde(default)]
82+
transcription: Option<Transcription>,
83+
}
84+
85+
#[derive(Debug, Deserialize)]
86+
struct ResultMetadata {
87+
#[serde(default)]
88+
audio_duration: Option<f64>,
89+
}
90+
91+
#[derive(Debug, Deserialize)]
92+
struct Transcription {
93+
#[serde(default)]
94+
full_transcript: Option<String>,
95+
#[serde(default)]
96+
utterances: Vec<Utterance>,
97+
}
98+
99+
#[derive(Debug, Deserialize)]
100+
struct Utterance {
101+
text: String,
102+
#[serde(default)]
103+
start: f64,
104+
#[serde(default)]
105+
end: f64,
106+
#[serde(default)]
107+
confidence: f64,
108+
#[serde(default)]
109+
channel: usize,
110+
#[serde(default)]
111+
speaker: Option<usize>,
112+
#[serde(default)]
113+
words: Vec<GladiaWord>,
114+
}
115+
116+
#[derive(Debug, Deserialize)]
117+
struct GladiaWord {
118+
word: String,
119+
#[serde(default)]
120+
start: f64,
121+
#[serde(default)]
122+
end: f64,
123+
#[serde(default)]
124+
confidence: f64,
125+
}
126+
127+
impl GladiaAdapter {
128+
async fn do_transcribe_file(
129+
client: &reqwest::Client,
130+
api_base: &str,
131+
api_key: &str,
132+
params: &ListenParams,
133+
file_path: PathBuf,
134+
) -> Result<BatchResponse, Error> {
135+
let base_url = Self::batch_api_url(api_base);
136+
137+
let audio_data = decode_audio_to_bytes(file_path).await?;
138+
139+
let upload_url = format!("{}/upload", base_url);
140+
let form = reqwest::multipart::Form::new().part(
141+
"audio",
142+
reqwest::multipart::Part::bytes(audio_data.to_vec())
143+
.file_name("audio.wav")
144+
.mime_str("audio/wav")
145+
.map_err(|e| Error::AudioProcessing(e.to_string()))?,
146+
);
147+
148+
let upload_response = client
149+
.post(&upload_url)
150+
.header("x-gladia-key", api_key)
151+
.multipart(form)
152+
.send()
153+
.await?;
154+
155+
let upload_status = upload_response.status();
156+
if !upload_status.is_success() {
157+
return Err(Error::UnexpectedStatus {
158+
status: upload_status,
159+
body: upload_response.text().await.unwrap_or_default(),
160+
});
161+
}
162+
163+
let upload_result: UploadResponse = upload_response.json().await?;
164+
165+
let languages: Vec<String> = params
166+
.languages
167+
.iter()
168+
.map(|l| l.iso639().code().to_string())
169+
.collect();
170+
171+
let language_config = if languages.is_empty() {
172+
None
173+
} else {
174+
Some(LanguageConfig {
175+
languages,
176+
code_switching: if params.languages.len() > 1 {
177+
Some(true)
178+
} else {
179+
None
180+
},
181+
})
182+
};
183+
184+
let transcript_request = TranscriptRequest {
185+
audio_url: upload_result.audio_url,
186+
language_config,
187+
diarization: Some(true),
188+
};
189+
190+
let transcript_url = format!("{}/pre-recorded", base_url);
191+
let create_response = client
192+
.post(&transcript_url)
193+
.header("x-gladia-key", api_key)
194+
.header("Content-Type", "application/json")
195+
.json(&transcript_request)
196+
.send()
197+
.await?;
198+
199+
let create_status = create_response.status();
200+
if !create_status.is_success() {
201+
return Err(Error::UnexpectedStatus {
202+
status: create_status,
203+
body: create_response.text().await.unwrap_or_default(),
204+
});
205+
}
206+
207+
let create_result: InitResponse = create_response.json().await?;
208+
let transcript_id = create_result.id;
209+
210+
let poll_url = format!("{}/pre-recorded/{}", base_url, transcript_id);
211+
212+
let config = PollingConfig::default()
213+
.with_interval(Duration::from_secs(3))
214+
.with_timeout_error("transcription timed out".to_string());
215+
216+
poll_until(
217+
|| async {
218+
let poll_response = client
219+
.get(&poll_url)
220+
.header("x-gladia-key", api_key)
221+
.send()
222+
.await?;
223+
224+
let poll_status = poll_response.status();
225+
if !poll_status.is_success() {
226+
return Err(Error::UnexpectedStatus {
227+
status: poll_status,
228+
body: poll_response.text().await.unwrap_or_default(),
229+
});
230+
}
231+
232+
let result: TranscriptResponse = poll_response.json().await?;
233+
234+
match result.status.as_str() {
235+
"done" => Ok(PollingResult::Complete(Self::convert_to_batch_response(
236+
result,
237+
))),
238+
"error" => {
239+
let error_msg = result
240+
.error_code
241+
.unwrap_or_else(|| "unknown error".to_string());
242+
Ok(PollingResult::Failed(format!(
243+
"transcription failed: {}",
244+
error_msg
245+
)))
246+
}
247+
_ => Ok(PollingResult::Continue),
248+
}
249+
},
250+
config,
251+
)
252+
.await
253+
}
254+
255+
fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
256+
let result = response.result.unwrap_or(TranscriptResult {
257+
metadata: None,
258+
transcription: None,
259+
});
260+
261+
let transcription = result.transcription.unwrap_or(Transcription {
262+
full_transcript: None,
263+
utterances: Vec::new(),
264+
});
265+
266+
let words: Vec<BatchWord> = transcription
267+
.utterances
268+
.iter()
269+
.flat_map(|u| {
270+
u.words.iter().map(|w| BatchWord {
271+
word: w.word.trim().to_string(),
272+
start: w.start,
273+
end: w.end,
274+
confidence: w.confidence,
275+
speaker: u.speaker,
276+
punctuated_word: Some(w.word.clone()),
277+
})
278+
})
279+
.collect();
280+
281+
let transcript = transcription.full_transcript.unwrap_or_default();
282+
283+
let avg_confidence = if words.is_empty() {
284+
1.0
285+
} else {
286+
words.iter().map(|w| w.confidence).sum::<f64>() / words.len() as f64
287+
};
288+
289+
let channel = BatchChannel {
290+
alternatives: vec![BatchAlternatives {
291+
transcript,
292+
confidence: avg_confidence,
293+
words,
294+
}],
295+
};
296+
297+
let audio_duration = result
298+
.metadata
299+
.and_then(|m| m.audio_duration)
300+
.or_else(|| response.file.and_then(|f| f.audio_duration));
301+
302+
BatchResponse {
303+
metadata: serde_json::json!({
304+
"audio_duration": audio_duration,
305+
}),
306+
results: BatchResults {
307+
channels: vec![channel],
308+
},
309+
}
310+
}
311+
}
312+
313+
async fn decode_audio_to_bytes(path: PathBuf) -> Result<bytes::Bytes, Error> {
314+
tokio::task::spawn_blocking(move || -> Result<bytes::Bytes, Error> {
315+
let decoder =
316+
source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?;
317+
318+
let channels = decoder.channels().max(1);
319+
let sample_rate = decoder.sample_rate();
320+
321+
let samples = resample_audio(decoder, sample_rate)
322+
.map_err(|err| Error::AudioProcessing(err.to_string()))?;
323+
324+
let samples = if channels == 1 {
325+
samples
326+
} else {
327+
let channels_usize = channels as usize;
328+
let mut mono = Vec::with_capacity(samples.len() / channels_usize);
329+
for frame in samples.chunks(channels_usize) {
330+
if frame.is_empty() {
331+
continue;
332+
}
333+
let sum: f32 = frame.iter().copied().sum();
334+
mono.push(sum / frame.len() as f32);
335+
}
336+
mono
337+
};
338+
339+
if samples.is_empty() {
340+
return Err(Error::AudioProcessing(
341+
"audio file contains no samples".to_string(),
342+
));
343+
}
344+
345+
let bytes = f32_to_i16_bytes(samples.into_iter());
346+
347+
Ok(bytes)
348+
})
349+
.await?
350+
}

0 commit comments

Comments
 (0)