From 57ff30c6d06e60a367d57ac734a37b6fc16871f8 Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Thu, 17 Jul 2025 14:42:08 -0700 Subject: [PATCH 1/3] fix typo in serialization for Container::None --- src/speak/options.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speak/options.rs b/src/speak/options.rs index eb061613..b66bd101 100644 --- a/src/speak/options.rs +++ b/src/speak/options.rs @@ -144,7 +144,7 @@ impl Container { match self { Container::Wav => "wav", Container::Ogg => "ogg", - Container::None => "nonne", + Container::None => "none", Container::CustomContainer(container) => container, } } From 2a1e140d02e320e4b312ffe2588e4b863610a13b Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Sat, 19 Jul 2025 10:00:36 -0700 Subject: [PATCH 2/3] websocket tts implementation --- src/speak/mod.rs | 1 + src/speak/websocket.rs | 372 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 373 insertions(+) create mode 100644 src/speak/websocket.rs diff --git a/src/speak/mod.rs b/src/speak/mod.rs index 7591494a..550cd83b 100644 --- a/src/speak/mod.rs +++ b/src/speak/mod.rs @@ -2,3 +2,4 @@ pub mod options; pub mod rest; +pub mod websocket; diff --git a/src/speak/websocket.rs b/src/speak/websocket.rs new file mode 100644 index 00000000..8a50cab0 --- /dev/null +++ b/src/speak/websocket.rs @@ -0,0 +1,372 @@ +#![allow(missing_docs)] +//! WebSocket TTS module + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{ + speak::options::{Encoding, Model}, + Deepgram, DeepgramError, Result, Speak, +}; + +use anyhow::anyhow; +use bytes::Bytes; +use futures::{select, SinkExt, Stream, StreamExt}; +use http::Request; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tungstenite::{handshake::client, Message}; +use url::Url; +use uuid::Uuid; + +static TTS_STREAM_PATH: &str = "v1/speak"; + +/// TODO docs +#[derive(Clone, Debug)] +pub struct WebsocketBuilder<'a> { + deepgram: &'a Deepgram, + encoding: Option, + model: Option, + sample_rate: Option, +} + +impl<'a> WebsocketBuilder<'a> { + pub fn as_url(&self) -> Result { + let mut url = + self.deepgram.base_url.join(TTS_STREAM_PATH).expect( + "base_url is checked to be a valid base_url when constructing Deepgram client", + ); + + match url.scheme() { + "http" | "ws" => url.set_scheme("ws").expect("a valid conversion according to the .set_scheme docs"), + "https" | "wss" => url.set_scheme("wss").expect("a valid conversion according to the .set_scheme docs"), + _ => unreachable!("base_url is validated to have a scheme of http, https, ws, or wss when constructing Deepgram client"), + } + + { + let mut pairs = url.query_pairs_mut(); + + if let Some(encoding) = self.encoding.as_ref() { + pairs.append_pair("encoding", encoding.as_str()); + } + + if let Some(model) = self.model.as_ref() { + pairs.append_pair("model", model.as_ref()); + } + + if let Some(sample_rate) = self.sample_rate { + pairs.append_pair("sample_rate", sample_rate.to_string().as_str()); + } + } + + Ok(url) + } + + pub async fn handle(self) -> Result { + WebsocketHandle::new(self).await + } + + pub async fn stream(self, stream: S) -> Result + where + S: Stream> + Send + Unpin + 'static, + E: std::error::Error + Send + Sync + 'static, + { + let handle = self.handle().await?; + let request_tx = handle.message_tx; + let mut text_stream = stream.fuse(); + let mut response_rx = ReceiverStream::new(handle.response_rx).fuse(); + + tokio::task::spawn(async move { + loop { + select! { + t = text_stream.next() => { + eprintln!("Text stream: {:?}", t); + match t { + Some(Ok(text)) => { + if let Err(_) = request_tx.send(SpeakWsMessage::Speak { text }).await { + break; + } + } + Some(Err(_err)) => { + break; + } + None => { + //when the text input stream closes, queue a close command + //on the websocket channel + let _ = request_tx.send(SpeakWsMessage::Close).await; + } + } + } + r = response_rx.next() => { + eprintln!("Response: {:?}", r); + } + } + } + }); + + let audio_stream = SpeakAudioStream { + rx: handle.audio_rx, + }; + + Ok(audio_stream) + } +} + +/// TODO docs +#[derive(Debug)] +pub struct WebsocketHandle { + message_tx: mpsc::Sender, + response_rx: mpsc::Receiver>, + audio_rx: mpsc::Receiver>, + request_id: Uuid, +} + +impl WebsocketHandle { + async fn new(builder: WebsocketBuilder<'_>) -> Result { + let url = builder.as_url()?; + let host = url.host_str().ok_or(DeepgramError::InvalidUrl)?; + + let request = { + let http_builder = Request::builder() + .method("GET") + .uri(url.to_string()) + .header("sec-websocket-key", client::generate_key()) + .header("host", host) + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13"); + + let builder = if let Some(auth) = &builder.deepgram.auth { + http_builder.header("authorization", auth.header_value()) + } else { + http_builder + }; + builder.body(())? + }; + + eprintln!("WS Speech Request: {:?}", request); + + let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?; + + let request_id = upgrade_response + .headers() + .get("dg-request-id") + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Websocket upgrade headers missing request ID" + )))? + .to_str() + .ok() + .and_then(|req_header_str| Uuid::parse_str(req_header_str).ok()) + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Received malformed request ID in websocket upgrade headers" + )))?; + + let (message_tx, message_rx) = mpsc::channel(256); + let (response_tx, response_rx) = mpsc::channel(256); + let (audio_tx, audio_rx) = mpsc::channel(256); + + tokio::task::spawn({ + let worker = WsWorker::new(ws_stream, message_rx, response_tx, audio_tx); + + async move { + if let Err(err) = worker.run().await { + tracing::error!("speak websocket worker error: {:?}", err); + } + } + }); + + Ok(WebsocketHandle { + message_tx, + response_rx, + audio_rx, + request_id, + }) + } + + pub fn request_id(&self) -> Uuid { + self.request_id + } + + pub async fn send_text(&self, text: String) -> Result<()> { + eprintln!("Sending text: {}", text); + if let Err(_) = self.message_tx.send(SpeakWsMessage::Speak { text }).await { + return Err(DeepgramError::UnexpectedServerResponse(anyhow!( + "websocket closed" + ))); + } + + Ok(()) + } + + pub async fn flush(&self) -> Result<()> { + let _ = self.message_tx.send(SpeakWsMessage::Flush).await; + Ok(()) + } +} + +#[derive(Debug)] +pub struct SpeakAudioStream { + rx: mpsc::Receiver>, +} + +impl Stream for SpeakAudioStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +impl<'a> Speak<'a> { + /// Opens a websocket connection to the Deepgram API to birectionally + /// stream text input and audio output + pub fn continuous_speak_to_stream(&self) -> WebsocketBuilder<'_> { + WebsocketBuilder { + deepgram: self.0, + encoding: Some(Encoding::Linear16), + model: Some(Model::CustomId("aura-2-thalia-en".to_string())), + sample_rate: Some(24000), + } + } +} + +/// TODO docs +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +pub enum SpeakWsMessage { + Speak { text: String }, + Flush, + Clear, + Close, +} + +/// TODO docs +#[derive(Debug)] +pub enum StreamResponse { + Audio(Bytes), + Control(SpeakResponse), +} + +/// TODO docs +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum SpeakResponse { + Flush { + sequence_id: u64, + }, + Clear { + sequence_id: u64, + }, + Close { + sequence_id: u64, + }, + StreamClosed { + code: u64, + reason: Option, + }, + Metadata { + request_id: String, + model_name: String, + }, +} + +#[derive(Debug)] +pub struct WsWorker { + ws_stream: WebSocketStream>, + request_rx: mpsc::Receiver, + response_tx: mpsc::Sender>, + audio_tx: mpsc::Sender>, +} + +impl WsWorker { + pub fn new( + ws_stream: WebSocketStream>, + request_rx: mpsc::Receiver, + response_tx: mpsc::Sender>, + audio_tx: mpsc::Sender>, + ) -> Self { + Self { + ws_stream, + request_rx, + response_tx, + audio_tx, + } + } + + async fn run(self) -> Result<()> { + let (mut ws_stream_send, ws_stream_recv) = self.ws_stream.split(); + let mut ws_recv = ws_stream_recv.fuse(); + let mut request_rx = ReceiverStream::new(self.request_rx).fuse(); + + loop { + select! { + response = ws_recv.next() => { + match response { + Some(Ok(Message::Text(response))) => { + eprintln!("Received text: {}", response); + match serde_json::from_str::(&response) { + Ok(response) => { + if (self.response_tx.send(Ok(response)).await).is_err() { + break; + } + } + Err(err) => { + if (self.response_tx.send(Err(err.into())).await).is_err() { + break; + } + } + } + } + Some(Ok(Message::Binary(audio))) => { + eprintln!("Received audio"); + if (self.audio_tx.send(Ok(audio)).await).is_err() { + break; + } + } + Some(Ok(Message::Close(_))) => { + return Ok(()) + } + Some(Ok(Message::Ping(ping))) => { + // We don't really care if the server receives the pong. + let _ = ws_stream_send.send(Message::Pong(ping)).await; + } + Some(Ok(Message::Pong(_))) => { } + Some(Ok(Message::Frame(_))) => { + eprintln!("Received frame"); + // We don't care about frames (I think). + } + Some(Err(err)) => { + if (self.response_tx.send(Err(err.into())).await).is_err() { + break; + } + } + None => { + return Ok(()) + } + } + } + + request = request_rx.next() => { + match request { + Some(request) => { + let msg = serde_json::to_string(&request)?; + eprintln!("Sending message: {}", msg); + if let Err(_) = ws_stream_send.send(Message::Text(msg.into())).await { + break; + } + } + None => { + return Ok(()) + } + } + } + } + } + + Ok(()) + } +} From b803517466a495e3d444a294954a7ef93f5d859d Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Thu, 24 Jul 2025 21:29:47 -0700 Subject: [PATCH 3/3] wip --- src/speak/options.rs | 4 ++++ src/speak/websocket.rs | 43 +++++++++++++++++++++++++++++++----------- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/speak/options.rs b/src/speak/options.rs index b66bd101..6b3b5fcc 100644 --- a/src/speak/options.rs +++ b/src/speak/options.rs @@ -50,6 +50,9 @@ pub enum Model { #[allow(missing_docs)] AuraZeusEn, + #[allow(missing_docs)] + Aura2ThaliaEn, + #[allow(missing_docs)] CustomId(String), } @@ -69,6 +72,7 @@ impl AsRef for Model { Self::AuraOrpheusEn => "aura-orpheus-en", Self::AuraHeliosEn => "aura-helios-en", Self::AuraZeusEn => "aura-zeus-en", + Self::Aura2ThaliaEn => "aura-2-thalia-en", Self::CustomId(id) => id, } } diff --git a/src/speak/websocket.rs b/src/speak/websocket.rs index 8a50cab0..b03aad97 100644 --- a/src/speak/websocket.rs +++ b/src/speak/websocket.rs @@ -32,6 +32,7 @@ pub struct WebsocketBuilder<'a> { encoding: Option, model: Option, sample_rate: Option, + no_delay: bool, } impl<'a> WebsocketBuilder<'a> { @@ -70,6 +71,26 @@ impl<'a> WebsocketBuilder<'a> { WebsocketHandle::new(self).await } + pub fn encoding(mut self, encoding: Encoding) -> Self { + self.encoding = Some(encoding); + self + } + + pub fn sample_rate(mut self, sample_rate: u32) -> Self { + self.sample_rate = Some(sample_rate); + self + } + + pub fn model(mut self, model: Model) -> Self { + self.model = Some(model); + self + } + + pub fn no_delay(mut self, no_delay: bool) -> Self { + self.no_delay = no_delay; + self + } + pub async fn stream(self, stream: S) -> Result where S: Stream> + Send + Unpin + 'static, @@ -84,7 +105,6 @@ impl<'a> WebsocketBuilder<'a> { loop { select! { t = text_stream.next() => { - eprintln!("Text stream: {:?}", t); match t { Some(Ok(text)) => { if let Err(_) = request_tx.send(SpeakWsMessage::Speak { text }).await { @@ -102,7 +122,9 @@ impl<'a> WebsocketBuilder<'a> { } } r = response_rx.next() => { - eprintln!("Response: {:?}", r); + if r.is_none() { + break; + } } } } @@ -148,9 +170,8 @@ impl WebsocketHandle { builder.body(())? }; - eprintln!("WS Speech Request: {:?}", request); - - let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?; + let (ws_stream, upgrade_response) = + tokio_tungstenite::connect_async_with_config(request, None, !builder.no_delay).await?; let request_id = upgrade_response .headers() @@ -192,7 +213,6 @@ impl WebsocketHandle { } pub async fn send_text(&self, text: String) -> Result<()> { - eprintln!("Sending text: {}", text); if let Err(_) = self.message_tx.send(SpeakWsMessage::Speak { text }).await { return Err(DeepgramError::UnexpectedServerResponse(anyhow!( "websocket closed" @@ -227,9 +247,10 @@ impl<'a> Speak<'a> { pub fn continuous_speak_to_stream(&self) -> WebsocketBuilder<'_> { WebsocketBuilder { deepgram: self.0, - encoding: Some(Encoding::Linear16), - model: Some(Model::CustomId("aura-2-thalia-en".to_string())), - sample_rate: Some(24000), + encoding: None, + model: None, + sample_rate: None, + no_delay: false, } } } @@ -322,7 +343,7 @@ impl WsWorker { } } Some(Ok(Message::Binary(audio))) => { - eprintln!("Received audio"); + // eprintln!("Received audio"); if (self.audio_tx.send(Ok(audio)).await).is_err() { break; } @@ -336,7 +357,7 @@ impl WsWorker { } Some(Ok(Message::Pong(_))) => { } Some(Ok(Message::Frame(_))) => { - eprintln!("Received frame"); + // eprintln!("Received frame"); // We don't care about frames (I think). } Some(Err(err)) => {