Skip to content

Commit 8036b72

Browse files
committed
inference/tts: return InferenceResult struct
Signed-off-by: Stijn Tintel <stijn@linux-ipv6.be>
1 parent 3525e15 commit 8036b72

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

src/inference/tts.rs

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@ use std::sync::{Arc, Mutex};
33

44
use anyhow::{Result, anyhow};
55
use sherpa_rs::tts::{TtsAudio, VitsTts, VitsTtsConfig};
6+
use tokio::time::Instant;
7+
8+
use crate::inference::InferenceResult;
69

710
#[derive(Clone)]
811
pub struct TtsEngine {
@@ -34,12 +37,39 @@ impl TtsEngine {
3437

3538
/// # Errors
3639
/// - when the mutex is poisoned
37-
pub fn synthesize(&self, text: &str, sid: i32, speed: f32) -> Result<TtsAudio> {
40+
pub fn synthesize(
41+
&self,
42+
text: &str,
43+
sid: i32,
44+
speed: f32,
45+
) -> Result<InferenceResult<TtsAudio>> {
3846
let mut tts = self
3947
.tts
4048
.lock()
4149
.map_err(|e| anyhow!("TTS mutex poisoned: {e:#?}"))?;
4250

43-
tts.create(text, sid, speed).map_err(|e| anyhow!("{e:#?}"))
51+
let start = Instant::now();
52+
53+
let speech = tts
54+
.create(text, sid, speed)
55+
.map_err(|e| anyhow!("{e:#?}"))?;
56+
57+
let time = start.elapsed().as_secs_f64();
58+
let time_ms = time * 1000.0;
59+
#[allow(clippy::cast_precision_loss)]
60+
let speedup = if time_ms > 0.0 {
61+
(f64::from(speech.duration)) / time_ms
62+
} else {
63+
0.0
64+
};
65+
66+
let result = InferenceResult {
67+
duration: u64::try_from(speech.duration).unwrap_or(0),
68+
output: speech,
69+
speedup,
70+
time,
71+
};
72+
73+
Ok(result)
4474
}
4575
}

src/routes/api/tts.rs

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ pub async fn get(
3838
.into_response();
3939
};
4040

41+
let text = parameters.text.clone();
4142
let tts_audio = spawn_blocking(move || tts_engine.synthesize(&parameters.text, 0, 1.0)).await;
4243

4344
let Ok(Ok(tts_audio)) = tts_audio else {
@@ -48,7 +49,15 @@ pub async fn get(
4849
.into_response();
4950
};
5051

51-
let Ok(audio_bytes) = encode_wav(&tts_audio.samples, tts_audio.sample_rate) else {
52+
tracing::info!(
53+
"inference took {}s: {} - speedup: {}x",
54+
tts_audio.time,
55+
text,
56+
tts_audio.speedup
57+
);
58+
59+
let Ok(audio_bytes) = encode_wav(&tts_audio.output.samples, tts_audio.output.sample_rate)
60+
else {
5261
return (
5362
StatusCode::INTERNAL_SERVER_ERROR,
5463
String::from("failed to encode audio in WAV"),

0 commit comments

Comments
 (0)