diff --git a/src/speak/mod.rs b/src/speak/mod.rs index 7591494a..550cd83b 100644 --- a/src/speak/mod.rs +++ b/src/speak/mod.rs @@ -2,3 +2,4 @@ pub mod options; pub mod rest; +pub mod websocket; diff --git a/src/speak/options.rs b/src/speak/options.rs index eb061613..6b3b5fcc 100644 --- a/src/speak/options.rs +++ b/src/speak/options.rs @@ -50,6 +50,9 @@ pub enum Model { #[allow(missing_docs)] AuraZeusEn, + #[allow(missing_docs)] + Aura2ThaliaEn, + #[allow(missing_docs)] CustomId(String), } @@ -69,6 +72,7 @@ impl AsRef for Model { Self::AuraOrpheusEn => "aura-orpheus-en", Self::AuraHeliosEn => "aura-helios-en", Self::AuraZeusEn => "aura-zeus-en", + Self::Aura2ThaliaEn => "aura-2-thalia-en", Self::CustomId(id) => id, } } @@ -144,7 +148,7 @@ impl Container { match self { Container::Wav => "wav", Container::Ogg => "ogg", - Container::None => "nonne", + Container::None => "none", Container::CustomContainer(container) => container, } } diff --git a/src/speak/websocket.rs b/src/speak/websocket.rs new file mode 100644 index 00000000..b03aad97 --- /dev/null +++ b/src/speak/websocket.rs @@ -0,0 +1,393 @@ +#![allow(missing_docs)] +//! WebSocket TTS module + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{ + speak::options::{Encoding, Model}, + Deepgram, DeepgramError, Result, Speak, +}; + +use anyhow::anyhow; +use bytes::Bytes; +use futures::{select, SinkExt, Stream, StreamExt}; +use http::Request; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tungstenite::{handshake::client, Message}; +use url::Url; +use uuid::Uuid; + +static TTS_STREAM_PATH: &str = "v1/speak"; + +/// TODO docs +#[derive(Clone, Debug)] +pub struct WebsocketBuilder<'a> { + deepgram: &'a Deepgram, + encoding: Option, + model: Option, + sample_rate: Option, + no_delay: bool, +} + +impl<'a> WebsocketBuilder<'a> { + pub fn as_url(&self) -> Result { + let mut url = + self.deepgram.base_url.join(TTS_STREAM_PATH).expect( + "base_url is checked to be a valid base_url when constructing Deepgram client", + ); + + match url.scheme() { + "http" | "ws" => url.set_scheme("ws").expect("a valid conversion according to the .set_scheme docs"), + "https" | "wss" => url.set_scheme("wss").expect("a valid conversion according to the .set_scheme docs"), + _ => unreachable!("base_url is validated to have a scheme of http, https, ws, or wss when constructing Deepgram client"), + } + + { + let mut pairs = url.query_pairs_mut(); + + if let Some(encoding) = self.encoding.as_ref() { + pairs.append_pair("encoding", encoding.as_str()); + } + + if let Some(model) = self.model.as_ref() { + pairs.append_pair("model", model.as_ref()); + } + + if let Some(sample_rate) = self.sample_rate { + pairs.append_pair("sample_rate", sample_rate.to_string().as_str()); + } + } + + Ok(url) + } + + pub async fn handle(self) -> Result { + WebsocketHandle::new(self).await + } + + pub fn encoding(mut self, encoding: Encoding) -> Self { + self.encoding = Some(encoding); + self + } + + pub fn sample_rate(mut self, sample_rate: u32) -> Self { + self.sample_rate = Some(sample_rate); + self + } + + pub fn model(mut self, model: Model) -> Self { + self.model = Some(model); + self + } + + pub fn no_delay(mut self, no_delay: bool) -> Self { + self.no_delay = no_delay; + self + } + + pub async fn stream(self, stream: S) -> Result + where + S: Stream> + Send + Unpin + 'static, + E: std::error::Error + Send + Sync + 'static, + { + let handle = self.handle().await?; + let request_tx = handle.message_tx; + let mut text_stream = stream.fuse(); + let mut response_rx = ReceiverStream::new(handle.response_rx).fuse(); + + tokio::task::spawn(async move { + loop { + select! { + t = text_stream.next() => { + match t { + Some(Ok(text)) => { + if let Err(_) = request_tx.send(SpeakWsMessage::Speak { text }).await { + break; + } + } + Some(Err(_err)) => { + break; + } + None => { + //when the text input stream closes, queue a close command + //on the websocket channel + let _ = request_tx.send(SpeakWsMessage::Close).await; + } + } + } + r = response_rx.next() => { + if r.is_none() { + break; + } + } + } + } + }); + + let audio_stream = SpeakAudioStream { + rx: handle.audio_rx, + }; + + Ok(audio_stream) + } +} + +/// TODO docs +#[derive(Debug)] +pub struct WebsocketHandle { + message_tx: mpsc::Sender, + response_rx: mpsc::Receiver>, + audio_rx: mpsc::Receiver>, + request_id: Uuid, +} + +impl WebsocketHandle { + async fn new(builder: WebsocketBuilder<'_>) -> Result { + let url = builder.as_url()?; + let host = url.host_str().ok_or(DeepgramError::InvalidUrl)?; + + let request = { + let http_builder = Request::builder() + .method("GET") + .uri(url.to_string()) + .header("sec-websocket-key", client::generate_key()) + .header("host", host) + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13"); + + let builder = if let Some(auth) = &builder.deepgram.auth { + http_builder.header("authorization", auth.header_value()) + } else { + http_builder + }; + builder.body(())? + }; + + let (ws_stream, upgrade_response) = + tokio_tungstenite::connect_async_with_config(request, None, !builder.no_delay).await?; + + let request_id = upgrade_response + .headers() + .get("dg-request-id") + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Websocket upgrade headers missing request ID" + )))? + .to_str() + .ok() + .and_then(|req_header_str| Uuid::parse_str(req_header_str).ok()) + .ok_or(DeepgramError::UnexpectedServerResponse(anyhow!( + "Received malformed request ID in websocket upgrade headers" + )))?; + + let (message_tx, message_rx) = mpsc::channel(256); + let (response_tx, response_rx) = mpsc::channel(256); + let (audio_tx, audio_rx) = mpsc::channel(256); + + tokio::task::spawn({ + let worker = WsWorker::new(ws_stream, message_rx, response_tx, audio_tx); + + async move { + if let Err(err) = worker.run().await { + tracing::error!("speak websocket worker error: {:?}", err); + } + } + }); + + Ok(WebsocketHandle { + message_tx, + response_rx, + audio_rx, + request_id, + }) + } + + pub fn request_id(&self) -> Uuid { + self.request_id + } + + pub async fn send_text(&self, text: String) -> Result<()> { + if let Err(_) = self.message_tx.send(SpeakWsMessage::Speak { text }).await { + return Err(DeepgramError::UnexpectedServerResponse(anyhow!( + "websocket closed" + ))); + } + + Ok(()) + } + + pub async fn flush(&self) -> Result<()> { + let _ = self.message_tx.send(SpeakWsMessage::Flush).await; + Ok(()) + } +} + +#[derive(Debug)] +pub struct SpeakAudioStream { + rx: mpsc::Receiver>, +} + +impl Stream for SpeakAudioStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +impl<'a> Speak<'a> { + /// Opens a websocket connection to the Deepgram API to birectionally + /// stream text input and audio output + pub fn continuous_speak_to_stream(&self) -> WebsocketBuilder<'_> { + WebsocketBuilder { + deepgram: self.0, + encoding: None, + model: None, + sample_rate: None, + no_delay: false, + } + } +} + +/// TODO docs +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +pub enum SpeakWsMessage { + Speak { text: String }, + Flush, + Clear, + Close, +} + +/// TODO docs +#[derive(Debug)] +pub enum StreamResponse { + Audio(Bytes), + Control(SpeakResponse), +} + +/// TODO docs +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum SpeakResponse { + Flush { + sequence_id: u64, + }, + Clear { + sequence_id: u64, + }, + Close { + sequence_id: u64, + }, + StreamClosed { + code: u64, + reason: Option, + }, + Metadata { + request_id: String, + model_name: String, + }, +} + +#[derive(Debug)] +pub struct WsWorker { + ws_stream: WebSocketStream>, + request_rx: mpsc::Receiver, + response_tx: mpsc::Sender>, + audio_tx: mpsc::Sender>, +} + +impl WsWorker { + pub fn new( + ws_stream: WebSocketStream>, + request_rx: mpsc::Receiver, + response_tx: mpsc::Sender>, + audio_tx: mpsc::Sender>, + ) -> Self { + Self { + ws_stream, + request_rx, + response_tx, + audio_tx, + } + } + + async fn run(self) -> Result<()> { + let (mut ws_stream_send, ws_stream_recv) = self.ws_stream.split(); + let mut ws_recv = ws_stream_recv.fuse(); + let mut request_rx = ReceiverStream::new(self.request_rx).fuse(); + + loop { + select! { + response = ws_recv.next() => { + match response { + Some(Ok(Message::Text(response))) => { + eprintln!("Received text: {}", response); + match serde_json::from_str::(&response) { + Ok(response) => { + if (self.response_tx.send(Ok(response)).await).is_err() { + break; + } + } + Err(err) => { + if (self.response_tx.send(Err(err.into())).await).is_err() { + break; + } + } + } + } + Some(Ok(Message::Binary(audio))) => { + // eprintln!("Received audio"); + if (self.audio_tx.send(Ok(audio)).await).is_err() { + break; + } + } + Some(Ok(Message::Close(_))) => { + return Ok(()) + } + Some(Ok(Message::Ping(ping))) => { + // We don't really care if the server receives the pong. + let _ = ws_stream_send.send(Message::Pong(ping)).await; + } + Some(Ok(Message::Pong(_))) => { } + Some(Ok(Message::Frame(_))) => { + // eprintln!("Received frame"); + // We don't care about frames (I think). + } + Some(Err(err)) => { + if (self.response_tx.send(Err(err.into())).await).is_err() { + break; + } + } + None => { + return Ok(()) + } + } + } + + request = request_rx.next() => { + match request { + Some(request) => { + let msg = serde_json::to_string(&request)?; + eprintln!("Sending message: {}", msg); + if let Err(_) = ws_stream_send.send(Message::Text(msg.into())).await { + break; + } + } + None => { + return Ok(()) + } + } + } + } + } + + Ok(()) + } +}