diff --git a/crates/transcribe-proxy/src/config.rs b/crates/transcribe-proxy/src/config.rs index 3b4eff60c3..f07e3a1a27 100644 --- a/crates/transcribe-proxy/src/config.rs +++ b/crates/transcribe-proxy/src/config.rs @@ -14,6 +14,7 @@ pub struct SttProxyConfig { pub default_provider: Provider, pub connect_timeout: Duration, pub analytics: Option>, + pub upstream_urls: HashMap, } impl SttProxyConfig { @@ -23,6 +24,7 @@ impl SttProxyConfig { default_provider: Provider::Deepgram, connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS), analytics: None, + upstream_urls: HashMap::new(), } } @@ -41,7 +43,16 @@ impl SttProxyConfig { self } + pub fn with_upstream_url(mut self, provider: Provider, url: impl Into) -> Self { + self.upstream_urls.insert(provider, url.into()); + self + } + pub fn api_key_for(&self, provider: Provider) -> Option<&str> { self.api_keys.get(&provider).map(|s| s.as_str()) } + + pub fn upstream_url_for(&self, provider: Provider) -> Option<&str> { + self.upstream_urls.get(&provider).map(|s| s.as_str()) + } } diff --git a/crates/transcribe-proxy/src/routes/streaming.rs b/crates/transcribe-proxy/src/routes/streaming.rs index 9b04a10ece..00936e46d3 100644 --- a/crates/transcribe-proxy/src/routes/streaming.rs +++ b/crates/transcribe-proxy/src/routes/streaming.rs @@ -38,20 +38,24 @@ pub async fn handler( let provider = resolved.provider(); - let proxy = match provider.auth() { - Auth::SessionInit { header_name } => { - let url = match init_session(&state, &resolved, header_name, ¶ms).await { - Ok(url) => url, - Err(e) => { - tracing::error!(error = %e, "failed to init session"); - return (StatusCode::BAD_GATEWAY, e).into_response(); - } - }; - build_proxy_with_url(&resolved, &url, &state.config) - } - _ => { - let base = url::Url::parse(&provider.default_ws_url()).unwrap(); - build_proxy_with_components(&resolved, base, params, &state.config) + let proxy = if let Some(custom_url) = state.config.upstream_url_for(provider) { + build_proxy_with_url(&resolved, custom_url, &state.config) + } else { + match provider.auth() { + Auth::SessionInit { header_name } => { + let url = match init_session(&state, &resolved, header_name, ¶ms).await { + Ok(url) => url, + Err(e) => { + tracing::error!(error = %e, "failed to init session"); + return (StatusCode::BAD_GATEWAY, e).into_response(); + } + }; + build_proxy_with_url(&resolved, &url, &state.config) + } + _ => { + let base = url::Url::parse(&provider.default_ws_url()).unwrap(); + build_proxy_with_components(&resolved, base, params, &state.config) + } } }; @@ -114,6 +118,32 @@ async fn init_session( Ok(init.url) } +macro_rules! finalize_proxy_builder { + ($builder:expr, $provider:expr, $config:expr) => { + match &$config.analytics { + Some(analytics) => { + let analytics = analytics.clone(); + let provider_name = format!("{:?}", $provider).to_lowercase(); + $builder + .on_close(move |duration| { + let analytics = analytics.clone(); + let provider_name = provider_name.clone(); + async move { + analytics + .report_stt(SttEvent { + provider: provider_name, + duration, + }) + .await; + } + }) + .build() + } + None => $builder.build(), + } + }; +} + fn build_proxy_with_url( resolved: &ResolvedProvider, upstream_url: &str, @@ -126,27 +156,7 @@ fn build_proxy_with_url( .control_message_types(provider.control_message_types()) .apply_auth(resolved); - match &config.analytics { - Some(analytics) => { - let analytics = analytics.clone(); - let provider_name = format!("{:?}", provider).to_lowercase(); - builder - .on_close(move |duration| { - let analytics = analytics.clone(); - let provider_name = provider_name.clone(); - async move { - analytics - .report_stt(SttEvent { - provider: provider_name, - duration, - }) - .await; - } - }) - .build() - } - None => builder.build(), - } + finalize_proxy_builder!(builder, provider, config) } fn build_proxy_with_components( @@ -162,25 +172,5 @@ fn build_proxy_with_components( .control_message_types(provider.control_message_types()) .apply_auth(resolved); - match &config.analytics { - Some(analytics) => { - let analytics = analytics.clone(); - let provider_name = format!("{:?}", provider).to_lowercase(); - builder - .on_close(move |duration| { - let analytics = analytics.clone(); - let provider_name = provider_name.clone(); - async move { - analytics - .report_stt(SttEvent { - provider: provider_name, - duration, - }) - .await; - } - }) - .build() - } - None => builder.build(), - } + finalize_proxy_builder!(builder, provider, config) } diff --git a/crates/transcribe-proxy/tests/common/fixtures.rs b/crates/transcribe-proxy/tests/common/fixtures.rs new file mode 100644 index 0000000000..1d30c36301 --- /dev/null +++ b/crates/transcribe-proxy/tests/common/fixtures.rs @@ -0,0 +1,14 @@ +use std::path::PathBuf; + +use super::recording::WsRecording; + +pub fn fixtures_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") +} + +pub fn load_fixture(name: &str) -> WsRecording { + let path = fixtures_dir().join(name); + WsRecording::from_jsonl_file(&path).unwrap() +} diff --git a/crates/transcribe-proxy/tests/common/mock_upstream.rs b/crates/transcribe-proxy/tests/common/mock_upstream.rs new file mode 100644 index 0000000000..0da25e5faf --- /dev/null +++ b/crates/transcribe-proxy/tests/common/mock_upstream.rs @@ -0,0 +1,199 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use futures_util::{SinkExt, StreamExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::protocol::CloseFrame; +use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; +use tokio_tungstenite::{WebSocketStream, accept_async}; + +use super::recording::{MessageKind, WsMessage, WsRecording}; + +#[derive(Debug, Clone)] +pub struct MockUpstreamConfig { + pub use_timing: bool, + pub max_delay_ms: u64, +} + +impl Default for MockUpstreamConfig { + fn default() -> Self { + Self { + use_timing: false, + max_delay_ms: 1000, + } + } +} + +impl MockUpstreamConfig { + pub fn use_timing(mut self, use_timing: bool) -> Self { + self.use_timing = use_timing; + self + } + + pub fn max_delay_ms(mut self, max_delay_ms: u64) -> Self { + self.max_delay_ms = max_delay_ms; + self + } +} + +struct MockUpstreamServer { + recording: WsRecording, + config: MockUpstreamConfig, + listener: TcpListener, +} + +impl MockUpstreamServer { + async fn with_config( + recording: WsRecording, + config: MockUpstreamConfig, + ) -> std::io::Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + Ok(Self { + recording, + config, + listener, + }) + } + + fn addr(&self) -> SocketAddr { + self.listener.local_addr().unwrap() + } + + async fn accept_one(&self) -> Result<(), MockUpstreamError> { + let (stream, _) = self.listener.accept().await?; + let ws_stream = accept_async(stream).await?; + self.handle_connection(ws_stream).await + } + + async fn handle_connection( + &self, + ws_stream: WebSocketStream, + ) -> Result<(), MockUpstreamError> { + let (mut sender, mut receiver) = ws_stream.split(); + + let server_messages: Vec<&WsMessage> = self + .recording + .messages + .iter() + .filter(|m| m.is_from_upstream()) + .collect(); + + let mut last_timestamp = 0u64; + let mut msg_index = 0; + + loop { + if msg_index >= server_messages.len() { + break; + } + + let msg = server_messages[msg_index]; + + if self.config.use_timing && msg.timestamp_ms > last_timestamp { + let delay = (msg.timestamp_ms - last_timestamp).min(self.config.max_delay_ms); + tokio::time::sleep(Duration::from_millis(delay)).await; + } + last_timestamp = msg.timestamp_ms; + + let ws_msg = ws_message_from_recorded(msg)?; + let is_close = matches!(msg.kind, MessageKind::Close { .. }); + + sender.send(ws_msg).await?; + msg_index += 1; + + if is_close { + break; + } + + while let Ok(Some(_)) = + tokio::time::timeout(Duration::from_millis(1), receiver.next()).await + {} + } + + Ok(()) + } +} + +fn ws_message_from_recorded(msg: &WsMessage) -> Result { + match &msg.kind { + MessageKind::Text => Ok(Message::Text(msg.content.clone().into())), + MessageKind::Binary => { + let data = msg.decode_binary()?; + Ok(Message::Binary(data.into())) + } + MessageKind::Close { code, reason } => Ok(Message::Close(Some(CloseFrame { + code: CloseCode::from(*code), + reason: reason.clone().into(), + }))), + MessageKind::Ping => { + let data = if msg.content.is_empty() { + vec![] + } else { + msg.decode_binary()? + }; + Ok(Message::Ping(data.into())) + } + MessageKind::Pong => { + let data = if msg.content.is_empty() { + vec![] + } else { + msg.decode_binary()? + }; + Ok(Message::Pong(data.into())) + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum MockUpstreamError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("WebSocket error: {0}")] + WebSocket(#[from] tokio_tungstenite::tungstenite::Error), + #[error("Base64 decode error: {0}")] + Base64(#[from] base64::DecodeError), +} + +pub struct MockServerHandle { + addr: SocketAddr, + #[allow(dead_code)] + shutdown_tx: tokio::sync::oneshot::Sender<()>, +} + +impl MockServerHandle { + pub fn ws_url(&self) -> String { + format!("ws://{}", self.addr) + } +} + +/// Starts a mock upstream server that replays recorded WebSocket messages. +/// +/// Note: This server only accepts a single connection. After one client connects +/// and the recording is replayed, the server will shut down. This is intentional +/// for test isolation - each test should create its own mock server instance. +pub async fn start_mock_server_with_config( + recording: WsRecording, + config: MockUpstreamConfig, +) -> std::io::Result { + let server = MockUpstreamServer::with_config(recording, config).await?; + let addr = server.addr(); + + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + tokio::select! { + result = server.accept_one() => { + if let Err(e) = result { + tracing::warn!("mock_server_error: {:?}", e); + } + } + _ = shutdown_rx => { + tracing::debug!("mock_server_shutdown"); + } + } + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + + Ok(MockServerHandle { addr, shutdown_tx }) +} diff --git a/crates/transcribe-proxy/tests/common/mod.rs b/crates/transcribe-proxy/tests/common/mod.rs index fbde97b7cf..19797d9f77 100644 --- a/crates/transcribe-proxy/tests/common/mod.rs +++ b/crates/transcribe-proxy/tests/common/mod.rs @@ -1,5 +1,16 @@ #![allow(dead_code)] +pub mod fixtures; +pub mod mock_upstream; +pub mod recording; + +#[allow(unused_imports)] +pub use fixtures::load_fixture; +#[allow(unused_imports)] +pub use mock_upstream::{MockServerHandle, MockUpstreamConfig, start_mock_server_with_config}; +#[allow(unused_imports)] +pub use recording::{Direction, MessageKind, WsMessage, WsRecording}; + use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; @@ -47,6 +58,17 @@ pub async fn start_server_with_provider(provider: Provider, api_key: String) -> start_server(config).await } +pub async fn start_server_with_upstream_url(provider: Provider, upstream_url: &str) -> SocketAddr { + let mut api_keys = HashMap::new(); + api_keys.insert(provider, "mock-api-key".to_string()); + + let config = SttProxyConfig::new(api_keys) + .with_default_provider(provider) + .with_upstream_url(provider, upstream_url); + + start_server(config).await +} + pub fn test_audio_stream() -> impl futures_util::Stream< Item = owhisper_interface::MixedMessage, > + Send diff --git a/crates/transcribe-proxy/tests/common/recording.rs b/crates/transcribe-proxy/tests/common/recording.rs new file mode 100644 index 0000000000..c0617bbe60 --- /dev/null +++ b/crates/transcribe-proxy/tests/common/recording.rs @@ -0,0 +1,273 @@ +use std::io::{BufRead, BufReader, Write}; +use std::path::Path; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use serde::{Deserialize, Serialize}; + +use owhisper_providers::Provider; + +fn encode_optional_binary(data: &[u8]) -> String { + if data.is_empty() { + String::new() + } else { + BASE64.encode(data) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Direction { + ServerToClient, + ClientToServer, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessageKind { + Text, + Binary, + Close { code: u16, reason: String }, + Ping, + Pong, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WsMessage { + pub direction: Direction, + pub timestamp_ms: u64, + pub kind: MessageKind, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub content: String, +} + +impl WsMessage { + pub fn text(direction: Direction, timestamp_ms: u64, content: impl Into) -> Self { + Self { + direction, + timestamp_ms, + kind: MessageKind::Text, + content: content.into(), + } + } + + pub fn binary(direction: Direction, timestamp_ms: u64, data: &[u8]) -> Self { + Self { + direction, + timestamp_ms, + kind: MessageKind::Binary, + content: BASE64.encode(data), + } + } + + pub fn close( + direction: Direction, + timestamp_ms: u64, + code: u16, + reason: impl Into, + ) -> Self { + Self { + direction, + timestamp_ms, + kind: MessageKind::Close { + code, + reason: reason.into(), + }, + content: String::new(), + } + } + + pub fn ping(direction: Direction, timestamp_ms: u64, data: &[u8]) -> Self { + Self { + direction, + timestamp_ms, + kind: MessageKind::Ping, + content: encode_optional_binary(data), + } + } + + pub fn pong(direction: Direction, timestamp_ms: u64, data: &[u8]) -> Self { + Self { + direction, + timestamp_ms, + kind: MessageKind::Pong, + content: encode_optional_binary(data), + } + } + + pub fn decode_binary(&self) -> Result, base64::DecodeError> { + BASE64.decode(&self.content) + } + + pub fn is_from_upstream(&self) -> bool { + self.direction == Direction::ServerToClient + } + + pub fn is_to_upstream(&self) -> bool { + self.direction == Direction::ClientToServer + } +} + +#[derive(Debug, Clone, Default)] +pub struct WsRecording { + pub messages: Vec, +} + +impl WsRecording { + pub fn from_jsonl_file(path: impl AsRef) -> std::io::Result { + let file = std::fs::File::open(path)?; + let reader = BufReader::new(file); + Self::from_reader(reader) + } + + #[allow(dead_code)] + pub fn from_jsonl_str(jsonl: &str) -> std::io::Result { + Self::from_reader(jsonl.as_bytes()) + } + + pub fn from_reader(reader: R) -> std::io::Result { + let mut messages = Vec::new(); + for line in reader.lines() { + let line = line?; + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + let msg: WsMessage = serde_json::from_str(trimmed) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + messages.push(msg); + } + Ok(Self { messages }) + } + + pub fn to_jsonl_file(&self, path: impl AsRef) -> std::io::Result<()> { + let mut file = std::fs::File::create(path)?; + for msg in &self.messages { + let line = serde_json::to_string(msg) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + writeln!(file, "{}", line)?; + } + Ok(()) + } + + pub fn server_messages(&self) -> impl Iterator { + self.messages.iter().filter(|m| m.is_from_upstream()) + } + + pub fn push(&mut self, message: WsMessage) { + self.messages.push(message); + } + + pub fn transform(mut self, f: F) -> Self + where + F: Fn(WsMessage) -> WsMessage, + { + self.messages = self.messages.into_iter().map(f).collect(); + self + } +} + +#[derive(Debug)] +pub struct WsRecorder { + start_time: Instant, + recording: WsRecording, +} + +impl Default for WsRecorder { + fn default() -> Self { + Self { + start_time: Instant::now(), + recording: WsRecording::default(), + } + } +} + +impl WsRecorder { + pub fn elapsed_ms(&self) -> u64 { + self.start_time.elapsed().as_millis() as u64 + } + + pub fn record_text(&mut self, direction: Direction, content: impl Into) { + let msg = WsMessage::text(direction, self.elapsed_ms(), content); + self.recording.push(msg); + } + + #[allow(dead_code)] + pub fn record_close(&mut self, direction: Direction, code: u16, reason: impl Into) { + let msg = WsMessage::close(direction, self.elapsed_ms(), code, reason); + self.recording.push(msg); + } + + pub fn recording(&self) -> &WsRecording { + &self.recording + } +} + +#[derive(Clone)] +pub struct RecordingSession { + recorder: Arc>, + provider: Provider, +} + +impl RecordingSession { + pub fn new(provider: Provider) -> Self { + Self { + recorder: Arc::new(Mutex::new(WsRecorder::default())), + provider, + } + } + + pub fn record_server_text(&self, content: &str) { + let mut recorder = self.recorder.lock().unwrap(); + recorder.record_text(Direction::ServerToClient, content); + } + + pub fn save_to_file(&self, dir: impl AsRef, suffix: &str) -> std::io::Result<()> { + let recorder = self.recorder.lock().unwrap(); + let recording = recorder.recording(); + + let filename = format!( + "{}_{}.jsonl", + self.provider.to_string().to_lowercase(), + suffix + ); + let path = dir.as_ref().join(filename); + + recording.to_jsonl_file(path) + } +} + +pub struct RecordingOptions { + pub enabled: bool, + pub output_dir: Option, + pub suffix: String, +} + +impl RecordingOptions { + /// Check if recording is enabled via environment variable. + /// Set RECORD_FIXTURES=1 to enable recording during live tests. + pub fn from_env(suffix: impl Into) -> Self { + let enabled = std::env::var("RECORD_FIXTURES") + .map(|v| v == "1" || v.to_lowercase() == "true") + .unwrap_or(false); + + if enabled { + Self { + enabled: true, + output_dir: Some( + std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures"), + ), + suffix: suffix.into(), + } + } else { + Self { + enabled: false, + output_dir: None, + suffix: suffix.into(), + } + } + } +} diff --git a/crates/transcribe-proxy/tests/fixtures/deepgram_auth_error.jsonl b/crates/transcribe-proxy/tests/fixtures/deepgram_auth_error.jsonl new file mode 100644 index 0000000000..202a436294 --- /dev/null +++ b/crates/transcribe-proxy/tests/fixtures/deepgram_auth_error.jsonl @@ -0,0 +1,2 @@ +{"direction":"server_to_client","timestamp_ms":50,"kind":{"type":"text"},"content":"{\"err_code\":\"INVALID_AUTH\",\"err_msg\":\"Invalid credentials.\",\"request_id\":\"test-request-id\"}"} +{"direction":"server_to_client","timestamp_ms":100,"kind":{"type":"close","code":1008,"reason":"policy violation"}} diff --git a/crates/transcribe-proxy/tests/fixtures/deepgram_normal.jsonl b/crates/transcribe-proxy/tests/fixtures/deepgram_normal.jsonl new file mode 100644 index 0000000000..b4900ab2c9 --- /dev/null +++ b/crates/transcribe-proxy/tests/fixtures/deepgram_normal.jsonl @@ -0,0 +1,7 @@ +{"direction":"server_to_client","timestamp_ms":50,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":0.0,\"start\":0.0,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"\",\"confidence\":0.0,\"words\":[]}]},\"metadata\":{\"request_id\":\"test-request-id\",\"model_info\":{\"name\":\"nova-3\",\"version\":\"2024-01-01\"},\"model_uuid\":\"test-model-uuid\"}}"} +{"direction":"server_to_client","timestamp_ms":500,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":0.5,\"start\":0.0,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Hello\",\"confidence\":0.95,\"words\":[{\"word\":\"Hello\",\"start\":0.1,\"end\":0.4,\"confidence\":0.95}]}]},\"metadata\":{\"request_id\":\"test-request-id\",\"model_info\":{\"name\":\"nova-3\",\"version\":\"2024-01-01\"},\"model_uuid\":\"test-model-uuid\"}}"} +{"direction":"server_to_client","timestamp_ms":1000,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":1.0,\"start\":0.0,\"is_final\":true,\"speech_final\":true,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Hello world\",\"confidence\":0.97,\"words\":[{\"word\":\"Hello\",\"start\":0.1,\"end\":0.4,\"confidence\":0.95},{\"word\":\"world\",\"start\":0.5,\"end\":0.9,\"confidence\":0.98}]}]},\"metadata\":{\"request_id\":\"test-request-id\",\"model_info\":{\"name\":\"nova-3\",\"version\":\"2024-01-01\"},\"model_uuid\":\"test-model-uuid\"}}"} +{"direction":"server_to_client","timestamp_ms":1500,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":1.5,\"start\":1.0,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"This is\",\"confidence\":0.92,\"words\":[{\"word\":\"This\",\"start\":1.1,\"end\":1.3,\"confidence\":0.91},{\"word\":\"is\",\"start\":1.35,\"end\":1.5,\"confidence\":0.93}]}]},\"metadata\":{\"request_id\":\"test-request-id\",\"model_info\":{\"name\":\"nova-3\",\"version\":\"2024-01-01\"},\"model_uuid\":\"test-model-uuid\"}}"} +{"direction":"server_to_client","timestamp_ms":2000,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":2.0,\"start\":1.0,\"is_final\":true,\"speech_final\":true,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"This is a test\",\"confidence\":0.96,\"words\":[{\"word\":\"This\",\"start\":1.1,\"end\":1.3,\"confidence\":0.91},{\"word\":\"is\",\"start\":1.35,\"end\":1.5,\"confidence\":0.93},{\"word\":\"a\",\"start\":1.55,\"end\":1.6,\"confidence\":0.98},{\"word\":\"test\",\"start\":1.65,\"end\":1.95,\"confidence\":0.99}]}]},\"metadata\":{\"request_id\":\"test-request-id\",\"model_info\":{\"name\":\"nova-3\",\"version\":\"2024-01-01\"},\"model_uuid\":\"test-model-uuid\"}}"} +{"direction":"server_to_client","timestamp_ms":2500,"kind":{"type":"text"},"content":"{\"type\":\"Metadata\",\"request_id\":\"test-request-id\",\"created\":\"2024-01-15T10:30:00.000Z\",\"duration\":2.5,\"channels\":1}"} +{"direction":"server_to_client","timestamp_ms":2600,"kind":{"type":"close","code":1000,"reason":"normal closure"}} diff --git a/crates/transcribe-proxy/tests/fixtures/deepgram_rate_limit.jsonl b/crates/transcribe-proxy/tests/fixtures/deepgram_rate_limit.jsonl new file mode 100644 index 0000000000..e33c65fa3e --- /dev/null +++ b/crates/transcribe-proxy/tests/fixtures/deepgram_rate_limit.jsonl @@ -0,0 +1,2 @@ +{"direction":"server_to_client","timestamp_ms":50,"kind":{"type":"text"},"content":"{\"err_code\":\"TOO_MANY_REQUESTS\",\"err_msg\":\"Too many requests. Please try again later\",\"request_id\":\"test-request-id\"}"} +{"direction":"server_to_client","timestamp_ms":100,"kind":{"type":"close","code":1008,"reason":"rate limit exceeded"}} diff --git a/crates/transcribe-proxy/tests/fixtures/soniox_error.jsonl b/crates/transcribe-proxy/tests/fixtures/soniox_error.jsonl new file mode 100644 index 0000000000..fc3627fad4 --- /dev/null +++ b/crates/transcribe-proxy/tests/fixtures/soniox_error.jsonl @@ -0,0 +1,2 @@ +{"direction":"server_to_client","timestamp_ms":50,"kind":{"type":"text"},"content":"{\"error_code\":503,\"error_message\":\"Cannot continue request (code 1). Please restart the request.\"}"} +{"direction":"server_to_client","timestamp_ms":100,"kind":{"type":"close","code":1011,"reason":"server error"}} diff --git a/crates/transcribe-proxy/tests/fixtures/soniox_normal.jsonl b/crates/transcribe-proxy/tests/fixtures/soniox_normal.jsonl new file mode 100644 index 0000000000..511d691b38 --- /dev/null +++ b/crates/transcribe-proxy/tests/fixtures/soniox_normal.jsonl @@ -0,0 +1,7 @@ +{"direction":"server_to_client","timestamp_ms":100,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":0.0,\"start\":0.0,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"\",\"confidence\":0.0,\"words\":[]}]},\"metadata\":{}}"} +{"direction":"server_to_client","timestamp_ms":600,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":0.6,\"start\":0.0,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Hello\",\"confidence\":0.93,\"words\":[{\"word\":\"Hello\",\"start\":0.15,\"end\":0.45,\"confidence\":0.93}]}]},\"metadata\":{}}"} +{"direction":"server_to_client","timestamp_ms":1100,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":1.1,\"start\":0.0,\"is_final\":true,\"speech_final\":true,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Hello world\",\"confidence\":0.96,\"words\":[{\"word\":\"Hello\",\"start\":0.15,\"end\":0.45,\"confidence\":0.93},{\"word\":\"world\",\"start\":0.55,\"end\":0.95,\"confidence\":0.97}]}]},\"metadata\":{}}"} +{"direction":"server_to_client","timestamp_ms":1600,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":1.6,\"start\":1.1,\"is_final\":false,\"speech_final\":false,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Testing\",\"confidence\":0.91,\"words\":[{\"word\":\"Testing\",\"start\":1.2,\"end\":1.55,\"confidence\":0.91}]}]},\"metadata\":{}}"} +{"direction":"server_to_client","timestamp_ms":2100,"kind":{"type":"text"},"content":"{\"type\":\"Results\",\"channel_index\":[0,1],\"duration\":2.1,\"start\":1.1,\"is_final\":true,\"speech_final\":true,\"from_finalize\":false,\"channel\":{\"alternatives\":[{\"transcript\":\"Testing Soniox\",\"confidence\":0.95,\"words\":[{\"word\":\"Testing\",\"start\":1.2,\"end\":1.55,\"confidence\":0.91},{\"word\":\"Soniox\",\"start\":1.6,\"end\":2.0,\"confidence\":0.98}]}]},\"metadata\":{}}"} +{"direction":"server_to_client","timestamp_ms":2200,"kind":{"type":"text"},"content":"{\"type\":\"Metadata\",\"request_id\":\"\",\"created\":\"\",\"duration\":2.2,\"channels\":1}"} +{"direction":"server_to_client","timestamp_ms":2300,"kind":{"type":"close","code":1000,"reason":"normal closure"}} diff --git a/crates/transcribe-proxy/tests/providers_e2e.rs b/crates/transcribe-proxy/tests/providers_e2e.rs index a7270e0eb9..4b7a12205b 100644 --- a/crates/transcribe-proxy/tests/providers_e2e.rs +++ b/crates/transcribe-proxy/tests/providers_e2e.rs @@ -1,4 +1,6 @@ mod common; + +use common::recording::{RecordingOptions, RecordingSession}; use common::*; use futures_util::StreamExt; @@ -11,6 +13,15 @@ use owhisper_providers::Provider; async fn run_proxy_live_test( provider: Provider, params: owhisper_interface::ListenParams, +) { + run_proxy_live_test_with_recording::(provider, params, RecordingOptions::from_env("normal")) + .await +} + +async fn run_proxy_live_test_with_recording( + provider: Provider, + params: owhisper_interface::ListenParams, + recording_opts: RecordingOptions, ) { let _ = tracing_subscriber::fmt::try_init(); @@ -18,6 +29,12 @@ async fn run_proxy_live_test( .unwrap_or_else(|_| panic!("{} must be set", provider.env_key_name())); let addr = start_server_with_provider(provider, api_key).await; + let recording_session = if recording_opts.enabled { + Some(RecordingSession::new(provider)) + } else { + None + }; + let client = ListenClient::builder() .adapter::() .api_base(format!("http://{}", addr)) @@ -36,15 +53,26 @@ async fn run_proxy_live_test( let test_future = async { while let Some(result) = stream.next().await { match result { - Ok(StreamResponse::TranscriptResponse { channel, .. }) => { - if let Some(alt) = channel.alternatives.first() { - if !alt.transcript.is_empty() { - println!("[{}] {}", provider_name, alt.transcript); - saw_transcript = true; + Ok(response) => { + // Record the response if recording is enabled + if let Some(ref session) = recording_session { + match serde_json::to_string(&response) { + Ok(json) => session.record_server_text(&json), + Err(e) => { + tracing::warn!("failed to serialize response for recording: {}", e) + } + } + } + + if let StreamResponse::TranscriptResponse { channel, .. } = &response { + if let Some(alt) = channel.alternatives.first() { + if !alt.transcript.is_empty() { + println!("[{}] {}", provider_name, alt.transcript); + saw_transcript = true; + } } } } - Ok(_) => {} Err(e) => { panic!("[{}] error: {:?}", provider_name, e); } @@ -55,6 +83,17 @@ async fn run_proxy_live_test( let _ = tokio::time::timeout(timeout, test_future).await; handle.finalize().await; + // Save recording if enabled + if let Some(session) = recording_session { + if let Some(ref output_dir) = recording_opts.output_dir { + std::fs::create_dir_all(output_dir).expect("failed to create fixtures directory"); + session + .save_to_file(output_dir, &recording_opts.suffix) + .expect("failed to save recording"); + println!("[{}] Recording saved to {:?}", provider_name, output_dir); + } + } + assert!( saw_transcript, "[{}] expected at least one non-empty transcript", diff --git a/crates/transcribe-proxy/tests/replay.rs b/crates/transcribe-proxy/tests/replay.rs new file mode 100644 index 0000000000..ae8e5dacea --- /dev/null +++ b/crates/transcribe-proxy/tests/replay.rs @@ -0,0 +1,258 @@ +mod common; + +use std::time::Duration; + +use futures_util::{SinkExt, StreamExt}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; + +use common::{ + MessageKind, MockUpstreamConfig, load_fixture, start_mock_server_with_config, + start_server_with_upstream_url, +}; +use owhisper_providers::Provider; + +const TEST_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5); + +async fn connect_to_proxy( + proxy_addr: std::net::SocketAddr, + model: &str, +) -> tokio_tungstenite::WebSocketStream> { + let url = format!( + "ws://{}/listen?model={}&encoding=linear16&sample_rate=16000&channels=1", + proxy_addr, model + ); + let (ws_stream, _) = connect_async(&url) + .await + .expect("Failed to connect to proxy"); + ws_stream +} + +async fn collect_messages( + ws_stream: tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + timeout: Duration, +) -> (Vec, Option<(u16, String)>) { + let (mut _sender, mut receiver) = ws_stream.split(); + let mut messages = Vec::new(); + let mut close_info = None; + + let collect_future = async { + while let Some(msg_result) = receiver.next().await { + match msg_result { + Ok(Message::Text(text)) => { + messages.push(text.to_string()); + } + Ok(Message::Close(frame)) => { + close_info = frame.map(|f| (f.code.into(), f.reason.to_string())); + break; + } + Ok(_) => {} + Err(e) => { + eprintln!("WebSocket error: {:?}", e); + break; + } + } + } + }; + + let _ = tokio::time::timeout(timeout, collect_future).await; + (messages, close_info) +} + +#[tokio::test] +async fn test_deepgram_normal_transcription_replay() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("deepgram_normal.jsonl"); + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = + start_server_with_upstream_url(Provider::Deepgram, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "nova-3").await; + let (messages, close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + assert!(!messages.is_empty(), "Expected to receive messages"); + + let has_hello_world = messages.iter().any(|m| m.contains("Hello world")); + let has_test = messages.iter().any(|m| m.contains("This is a test")); + assert!(has_hello_world, "Expected 'Hello world' transcript"); + assert!(has_test, "Expected 'This is a test' transcript"); + + if let Some((code, _reason)) = close_info { + assert_eq!(code, 1000, "Expected normal close code 1000"); + } +} + +#[tokio::test] +async fn test_deepgram_auth_error_replay() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("deepgram_auth_error.jsonl"); + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = + start_server_with_upstream_url(Provider::Deepgram, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "nova-3").await; + let (messages, close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + assert!(!messages.is_empty(), "Expected to receive error message"); + let has_auth_error = messages + .iter() + .any(|m| m.contains("INVALID_AUTH") || m.contains("Invalid credentials")); + assert!(has_auth_error, "Expected auth error message"); + + if let Some((code, _reason)) = close_info { + assert!( + code == 4401 || code == 1008, + "Expected close code 4401 or 1008, got {}", + code + ); + } +} + +#[tokio::test] +async fn test_deepgram_rate_limit_replay() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("deepgram_rate_limit.jsonl"); + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = + start_server_with_upstream_url(Provider::Deepgram, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "nova-3").await; + let (messages, close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + let has_rate_limit = messages + .iter() + .any(|m| m.contains("TOO_MANY_REQUESTS") || m.contains("Too many requests")); + assert!(has_rate_limit, "Expected rate limit error message"); + + if let Some((code, _reason)) = close_info { + assert!( + code == 4429 || code == 1008, + "Expected close code 4429 or 1008, got {}", + code + ); + } +} + +#[tokio::test] +async fn test_soniox_normal_transcription_replay() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("soniox_normal.jsonl"); + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = start_server_with_upstream_url(Provider::Soniox, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "stt-v3").await; + let (messages, close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + assert!(!messages.is_empty(), "Expected to receive messages"); + + let has_hello_world = messages.iter().any(|m| m.contains("Hello world")); + let has_soniox = messages.iter().any(|m| m.contains("Soniox")); + assert!(has_hello_world, "Expected 'Hello world' transcript"); + assert!(has_soniox, "Expected 'Soniox' transcript"); + + if let Some((code, _reason)) = close_info { + assert_eq!(code, 1000, "Expected normal close code 1000"); + } +} + +#[tokio::test] +async fn test_soniox_error_replay() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("soniox_error.jsonl"); + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = start_server_with_upstream_url(Provider::Soniox, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "stt-v3").await; + let (messages, close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + let has_error = messages + .iter() + .any(|m| m.contains("error_code") || m.contains("Cannot continue request")); + assert!(has_error, "Expected error message"); + + if let Some((code, _reason)) = close_info { + assert!( + code == 4500 || code == 1011, + "Expected close code 4500 or 1011, got {}", + code + ); + } +} + +#[tokio::test] +async fn test_proxy_forwards_all_messages() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("deepgram_normal.jsonl"); + let expected_text_count = recording + .server_messages() + .filter(|m| matches!(m.kind, MessageKind::Text)) + .count(); + + let mock_handle = start_mock_server_with_config(recording, MockUpstreamConfig::default()) + .await + .expect("Failed to start mock server"); + + let proxy_addr = + start_server_with_upstream_url(Provider::Deepgram, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "nova-3").await; + let (messages, _close_info) = collect_messages(ws_stream, TEST_RESPONSE_TIMEOUT).await; + + assert_eq!( + messages.len(), + expected_text_count, + "Expected {} messages, got {}", + expected_text_count, + messages.len() + ); +} + +#[tokio::test] +async fn test_proxy_handles_client_disconnect() { + let _ = tracing_subscriber::fmt::try_init(); + + let recording = load_fixture("deepgram_normal.jsonl"); + let mock_handle = start_mock_server_with_config( + recording, + MockUpstreamConfig::default() + .use_timing(true) + .max_delay_ms(100), + ) + .await + .expect("Failed to start mock server"); + + let proxy_addr = + start_server_with_upstream_url(Provider::Deepgram, &mock_handle.ws_url()).await; + + let ws_stream = connect_to_proxy(proxy_addr, "nova-3").await; + let (mut sender, mut receiver) = ws_stream.split(); + + if let Some(msg) = receiver.next().await { + assert!(msg.is_ok(), "Expected first message to succeed"); + } + + let _ = sender.send(Message::Close(None)).await; +}