diff --git a/Cargo.lock b/Cargo.lock index 8ae394d0fd..ea059f4793 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11936,6 +11936,7 @@ dependencies = [ "axum 0.7.9", "client-ip", "derive_more 1.0.0", + "flate2", "futures", "proptest", "proptest-derive 0.5.1", diff --git a/crates/full-node/sov-sequencer/src/rest_api.rs b/crates/full-node/sov-sequencer/src/rest_api.rs index cf3a515ced..13de078b4c 100644 --- a/crates/full-node/sov-sequencer/src/rest_api.rs +++ b/crates/full-node/sov-sequencer/src/rest_api.rs @@ -19,8 +19,9 @@ use sov_modules_api::{FullyBakedTx, RawTx, RuntimeEventProcessor, RuntimeEventRe use sov_rest_utils::handle_bad_ws_request; use sov_rest_utils::send_json; use sov_rest_utils::{ - errors, preconfigured_router_layers, serve_generic_ws_subscription, ApiResult, FilterQuery, - PageSelection, PaginatedResponse, Pagination, Path, Query, + errors, preconfigured_router_layers, serve_generic_ws_subscription, + serve_generic_ws_subscription_with_config, ApiResult, FilterQuery, PageSelection, + PaginatedResponse, Pagination, Path, Query, WsSubscriptionConfig, }; use sov_rest_utils::{get_client_ip, WsMessage}; use sov_rollup_interface::da::{DaBlobHash, DaSpec}; @@ -75,6 +76,34 @@ pub struct StartFrom { start_from: u64, } +/// Compression mode for WebSocket subscriptions. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CompressionMode { + /// No compression (default). Messages are sent as individual JSON text frames. + #[default] + None, + /// Gzip compression. Messages are batched into JSON arrays, compressed, and sent as binary frames. + Gzip, +} + +/// Query parameters for WebSocket subscriptions that support compression. +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct CompressionQuery { + /// The compression mode to use for this subscription. + #[serde(default)] + pub compression: CompressionMode, +} + +impl CompressionQuery { + /// Converts the query into a [`WsSubscriptionConfig`]. + pub fn to_config(&self) -> WsSubscriptionConfig { + WsSubscriptionConfig { + compress: self.compression == CompressionMode::Gzip, + } + } +} + /// Provides REST APIs for any [`Sequencer`]. See [`SequencerApis::rest_api_server`]. #[derive(derivative::Derivative)] #[derivative(Clone(bound = ""))] @@ -480,10 +509,12 @@ impl SequencerApis { async fn subscribe_to_events( State(state): State, filter: FilterQuery, + compression: Option>, ws: WebSocketUpgrade, ) -> impl IntoResponse { use futures::future; - ws.on_upgrade(|socket| async move { + let config = compression.map(|q| q.0.to_config()).unwrap_or_default(); + ws.on_upgrade(move |socket| async move { let stream = state .sequencer .subscribe_events() @@ -494,20 +525,34 @@ impl SequencerApis { (Ok(event), Some(filter)) => future::ready(filter.matches(&event.key)), (_, _) => future::ready(true), }); - serve_generic_ws_subscription(socket, stream, state.shutdown_receiver.clone()).await; + serve_generic_ws_subscription_with_config( + socket, + stream, + state.shutdown_receiver.clone(), + config, + ) + .await; }) } async fn subscribe_to_transactions( State(state): State, start_from: Option>, + compression: Option>, ws: WebSocketUpgrade, ) -> impl IntoResponse { let start_from = start_from.map(|start_from| start_from.0.start_from); + let config = compression.map(|q| q.0.to_config()).unwrap_or_default(); ws.on_upgrade(move |socket| async move { let stream = Self::subscribe_txs_starting_from(start_from, state.sequencer.clone()).await; - serve_generic_ws_subscription(socket, stream, state.shutdown_receiver.clone()).await; + serve_generic_ws_subscription_with_config( + socket, + stream, + state.shutdown_receiver.clone(), + config, + ) + .await; }) } diff --git a/crates/utils/sov-rest-utils/Cargo.toml b/crates/utils/sov-rest-utils/Cargo.toml index 7144e288fd..8714063f40 100644 --- a/crates/utils/sov-rest-utils/Cargo.toml +++ b/crates/utils/sov-rest-utils/Cargo.toml @@ -14,6 +14,7 @@ anyhow = { workspace = true } client-ip = "0.1.1" axum = { workspace = true, features = ["query", "ws", "json", "original-uri"] } derive_more = { workspace = true, default-features = true, features = ["deref", "display"] } +flate2 = { version = "1", default-features = false, features = ["rust_backend"] } futures = { workspace = true } proptest = { workspace = true, features = ["std"], optional = true } proptest-derive = { workspace = true, optional = true } @@ -35,6 +36,7 @@ tokio-tungstenite = { version = "0.28.0", default-features = false, features = [ tokio-stream = { workspace = true } tungstenite = "0.28.0" axum = { workspace = true, features = ["query", "ws", "json", "original-uri", "tokio", "http1"] } +flate2 = { version = "1", default-features = false, features = ["rust_backend"] } [features] arbitrary = ["proptest", "proptest-derive", "sov-rest-utils/arbitrary"] diff --git a/crates/utils/sov-rest-utils/src/lib.rs b/crates/utils/sov-rest-utils/src/lib.rs index cbe303dc60..18dc08c6e0 100644 --- a/crates/utils/sov-rest-utils/src/lib.rs +++ b/crates/utils/sov-rest-utils/src/lib.rs @@ -180,6 +180,13 @@ pub fn cors_layer_opt( const MAX_BATCH_SIZE: usize = 128; +/// Configuration for WebSocket subscription behavior. +#[derive(Debug, Clone, Copy, Default)] +pub struct WsSubscriptionConfig { + /// When true, messages are batched into arrays and gzip-compressed before sending. + pub compress: bool, +} + /// Interval between ping frames sent to the client for keepalive. const PING_INTERVAL: std::time::Duration = std::time::Duration::from_secs(30); @@ -195,9 +202,42 @@ const PONG_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); /// - Graceful shutdown on server shutdown signal /// - Proper handling of client disconnection and half-closed connections pub async fn serve_generic_ws_subscription( + socket: WebSocket, + subscription: S, + shutdown_receiver: tokio::sync::watch::Receiver<()>, +) where + S: futures::Stream> + Unpin, + E: ReportableWsError, + M: Clone + serde::Serialize + Send + Sync + 'static, +{ + serve_generic_ws_subscription_with_config( + socket, + subscription, + shutdown_receiver, + WsSubscriptionConfig::default(), + ) + .await +} + +/// A utility function for serving some data inside a [`futures::Stream`] over a +/// WebSocket connection, with configurable behavior. +/// +/// This function handles: +/// - Sending data from the subscription stream to the client +/// - Periodic ping/pong keepalive to detect dead connections +/// - Graceful shutdown on server shutdown signal +/// - Proper handling of client disconnection and half-closed connections +/// - Optional gzip compression of batched messages (when `config.compress` is true) +/// +/// When compression is enabled: +/// - Data messages are batched into JSON arrays, gzip-compressed, and sent as binary frames +/// - Error messages are also gzip-compressed and sent as binary frames +/// - Clients can uniformly decompress all binary frames (gzip magic bytes: `0x1f 0x8b`) +pub async fn serve_generic_ws_subscription_with_config( mut socket: WebSocket, subscription: S, mut shutdown_receiver: tokio::sync::watch::Receiver<()>, + config: WsSubscriptionConfig, ) where S: futures::Stream> + Unpin, E: ReportableWsError, @@ -257,13 +297,19 @@ pub async fn serve_generic_ws_subscription( }, Some(Ok(_)) => { // Client sent an unexpected message - notify them it was ignored - trace!("Incoming WebSocket message but none was expected; notifying client"); - if let Err(err) = send_json(&mut socket, &ErrorObject { + let error = ErrorObject { status: StatusCode::BAD_REQUEST, message: "This subscription does not accept incoming messages".to_string(), details: JsonObject::new(), - }).await { - warn!(?err, "Failed to send error response - disconnecting client"); + }; + trace!("Incoming WebSocket message but none was expected; notifying client"); + if config.compress { + if let Err(e) = feed_compressed_bytes(&mut socket, serde_json::to_string(&error).expect("Failed to serialize error as JSON. This is a bug, please report it").as_bytes()).await { + warn!(?e, "Failed to send error response - disconnecting client"); + break; + } + } else if let Err(e) = send_json(&mut socket, &error).await { + warn!(?e, "Failed to send error response - disconnecting client"); break; } }, @@ -272,31 +318,68 @@ pub async fn serve_generic_ws_subscription( chunk_opt = chunked_subscription.next() => { match chunk_opt { Some(chunk) => { - for item in chunk { - match item { - Ok(data) => { - let serialized = match serde_json::to_string(&data) { - Ok(serialized) => serialized, - Err(err) => { - error!(?err, "Failed to serialize data for WebSocket; this is a bug, please report it"); + if config.compress { + // Compressed mode: batch successful items, compress, send as binary + let mut batch: Vec = Vec::with_capacity(chunk.len()); + for item in chunk { + match item { + Ok(data) => { + batch.push(data); + } + Err(err) => { + // Flush any accumulated batch before sending error + if !batch.is_empty() { + if send_compressed_batch(&mut socket, &batch).await.is_err() { + break 'outer; + } + batch.clear(); + } + + // Send compressed error + if feed_compressed_bytes(&mut socket, err.to_json().as_bytes()).await.is_err() { + break 'outer; + } + + if !err.is_recoverable() { + // Note that breaking out of the loop will also flush the socket, so we don't need to do it here. break 'outer; } - }; - if let Err(err) = socket.feed(serialized.into()).await { - warn!(?err, "WebSocket send error - disconnecting client"); - break 'outer; } } - Err(err) => { - // Send error notification to the client - if let Err(send_err) = socket.send(err.to_json().into()).await { - warn!(err=?send_err, "WebSocket send error - disconnecting client"); - break 'outer; + } + + // Send remaining batch + if !batch.is_empty() && send_compressed_batch(&mut socket, &batch).await.is_err() { + break 'outer; + } + } else { + // Uncompressed mode: send individual text messages (original behavior) + for item in chunk { + match item { + Ok(data) => { + let serialized = match serde_json::to_string(&data) { + Ok(serialized) => serialized, + Err(err) => { + error!(?err, "Failed to serialize data for WebSocket; this is a bug, please report it"); + break 'outer; + } + }; + if let Err(err) = socket.feed(serialized.into()).await { + warn!(?err, "WebSocket send error - disconnecting client"); + break 'outer; + } } - // For recoverable errors (e.g., lag), continue streaming - // For non-recoverable errors, disconnect - if !err.is_recoverable() { - break 'outer; + Err(err) => { + // Send error notification to the client + if let Err(send_err) = socket.send(err.to_json().into()).await { + warn!(err=?send_err, "WebSocket send error - disconnecting client"); + break 'outer; + } + // For recoverable errors (e.g., lag), continue streaming + // For non-recoverable errors, disconnect + if !err.is_recoverable() { + break 'outer; + } } } } @@ -343,6 +426,72 @@ pub async fn serve_generic_ws_subscription( socket.close().await.ok(); } +/// Compresses bytes with gzip. +fn compress_bytes(data: &[u8]) -> Result, std::io::Error> { + use flate2::write::GzEncoder; + use flate2::Compression; + use std::io::Write; + + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(data)?; + encoder.finish() +} + +/// Serializes the value as JSON and compresses it with gzip. +fn compress_json(value: &T) -> Result, std::io::Error> { + let json = serde_json::to_vec(value) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + compress_bytes(&json) +} + +/// Compresses and sends a batch of items as a binary WebSocket message. +/// Returns Err(()) if the send fails or compression fails. +async fn send_compressed_batch( + socket: &mut WebSocket, + batch: &[T], +) -> Result<(), ()> { + use axum::extract::ws::Message; + + match compress_json(batch) { + Ok(compressed) => { + if let Err(err) = socket.feed(Message::Binary(compressed)).await { + warn!(?err, "WebSocket send error - disconnecting client"); + return Err(()); + } + Ok(()) + } + Err(err) => { + error!( + ?err, + "Failed to serialize/compress data for WebSocket; this is a bug, please report it" + ); + Err(()) + } + } +} + +/// Compresses and sends raw bytes as a binary WebSocket message. +/// Returns Err(()) if the send fails or compression fails. +async fn feed_compressed_bytes(socket: &mut WebSocket, data: &[u8]) -> Result<(), anyhow::Error> { + use axum::extract::ws::Message; + + match compress_bytes(data) { + Ok(compressed) => { + if let Err(err) = socket.feed(Message::Binary(compressed)).await { + return Err(err.into()); + } + Ok(()) + } + Err(err) => { + error!( + ?err, + "Failed to compress data for WebSocket; this is a bug, please report it" + ); + Err(err.into()) + } + } +} + /// A message that can be received via websocket. #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] pub struct WsMessage { diff --git a/crates/utils/sov-rest-utils/src/ws_tests.rs b/crates/utils/sov-rest-utils/src/ws_tests.rs index c74dc07e6a..44088e60a9 100644 --- a/crates/utils/sov-rest-utils/src/ws_tests.rs +++ b/crates/utils/sov-rest-utils/src/ws_tests.rs @@ -4,6 +4,7 @@ //! 1. Socket errors (feed/flush) cause immediate disconnect //! 2. BroadcastStream lag sends skip notification and resumes (not disconnect) //! 3. Ping/pong keepalive detects dead connections +//! 4. Gzip compression mode batches and compresses messages correctly #[cfg(test)] mod tests { @@ -25,7 +26,10 @@ mod tests { use tokio_stream::wrappers::BroadcastStream; use crate::errors::ReportableWsError; - use crate::serve_generic_ws_subscription; + use crate::{ + serve_generic_ws_subscription, serve_generic_ws_subscription_with_config, + WsSubscriptionConfig, + }; /// Error type matching the real SubscriptionStreamError pattern. #[derive(Debug, Clone)] @@ -201,12 +205,43 @@ mod tests { }) } - /// Starts a test server and returns the address. + /// Handler that uses compression for WebSocket messages. + async fn compressed_handler( + ws: WebSocketUpgrade, + State(state): State, + ) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + let (stream, tx) = broadcast_message_stream(100, state.messages_yielded.clone()); + + let messages_produced = state.messages_produced.clone(); + let producer = tokio::spawn(async move { + for i in 0..20 { + if tx.send(format!("message-{i}")).is_err() { + break; + } + messages_produced.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(10)).await; + } + }); + + serve_generic_ws_subscription_with_config( + socket, + stream, + state.shutdown_rx.clone(), + WsSubscriptionConfig { compress: true }, + ) + .await; + producer.abort(); + }) + } + + /// Starts a test server with all handlers and returns the address. async fn start_test_server(state: TestState) -> SocketAddr { let app = Router::new() .route("/broadcast", get(broadcast_handler)) .route("/slow", get(slow_broadcast_handler)) .route("/idle", get(idle_handler)) + .route("/compressed", get(compressed_handler)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -498,4 +533,145 @@ mod tests { state.shutdown_tx.send(()).ok(); } + + // ========================================================================= + // Compression Tests + // ========================================================================= + + /// Test: Compressed messages should be binary frames with gzip magic bytes. + #[tokio::test] + async fn test_compressed_messages_are_binary_gzip() { + use flate2::read::GzDecoder; + use std::io::Read; + + let state = TestState::new(); + let addr = start_test_server(state.clone()).await; + + let (mut ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}/compressed")) + .await + .unwrap(); + + let mut binary_count = 0; + let mut all_messages: Vec = Vec::new(); + + let result = timeout(Duration::from_secs(5), async { + loop { + match ws.next().await { + Some(Ok(tungstenite::Message::Binary(data))) => { + binary_count += 1; + + // Verify gzip magic bytes + assert!( + data.len() >= 2 && data[0] == 0x1f && data[1] == 0x8b, + "Binary frame should start with gzip magic bytes (0x1f 0x8b), got: {:02x} {:02x}", + data.first().copied().unwrap_or(0), + data.get(1).copied().unwrap_or(0) + ); + + // Decompress + let mut decoder = GzDecoder::new(&data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + // Parse as JSON array + let batch: Vec = serde_json::from_str(&decompressed).unwrap(); + all_messages.extend(batch); + } + Some(Ok(tungstenite::Message::Text(_))) => { + panic!("Compressed mode should not send text frames"); + } + Some(Ok(tungstenite::Message::Ping(_))) | Some(Ok(tungstenite::Message::Pong(_))) => continue, + Some(Ok(tungstenite::Message::Close(_))) | None | Some(Err(_)) => break, + _ => {} + } + } + }) + .await; + + assert!(result.is_ok(), "Test should complete without timeout"); + assert!(binary_count > 0, "Should have received binary frames"); + assert!(!all_messages.is_empty(), "Should have received messages"); + + // Verify message ordering is preserved + for (i, msg) in all_messages.iter().enumerate() { + assert_eq!(msg, &format!("message-{i}"), "Messages should be in order"); + } + + state.shutdown_tx.send(()).ok(); + } + + /// Test: Default config (no compression) should produce text frames. + #[tokio::test] + async fn test_default_config_produces_text_frames() { + let state = TestState::new(); + let addr = start_test_server(state.clone()).await; + + // Use /slow endpoint which uses default (uncompressed) mode + let (mut ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}/slow")) + .await + .unwrap(); + + let mut text_count = 0; + + let result = timeout(Duration::from_secs(3), async { + while text_count < 5 { + match ws.next().await { + Some(Ok(tungstenite::Message::Text(t))) => { + text_count += 1; + // Verify it's a single JSON string, not an array + let text = t.to_string(); + assert!( + text.starts_with('"'), + "Uncompressed mode should send individual JSON values, not arrays" + ); + } + Some(Ok(tungstenite::Message::Binary(_))) => { + panic!("Uncompressed mode should not send binary frames for data"); + } + Some(Ok(tungstenite::Message::Ping(_))) + | Some(Ok(tungstenite::Message::Pong(_))) => continue, + Some(Ok(tungstenite::Message::Close(_))) | None | Some(Err(_)) => break, + _ => {} + } + } + }) + .await; + + assert!(result.is_ok(), "Test should complete without timeout"); + assert!(text_count >= 5, "Should have received text frames"); + + ws.close(None).await.ok(); + state.shutdown_tx.send(()).ok(); + } + + /// Test: Gzip roundtrip preserves data correctly. + #[test] + fn test_gzip_roundtrip() { + use flate2::read::GzDecoder; + use flate2::write::GzEncoder; + use flate2::Compression; + use std::io::{Read, Write}; + + let original = vec!["message-0", "message-1", "message-2"]; + let json = serde_json::to_vec(&original).unwrap(); + + // Compress + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&json).unwrap(); + let compressed = encoder.finish().unwrap(); + + // Verify magic bytes + assert_eq!(compressed[0], 0x1f, "First byte should be gzip magic"); + assert_eq!(compressed[1], 0x8b, "Second byte should be gzip magic"); + + // Decompress + let mut decoder = GzDecoder::new(&compressed[..]); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed).unwrap(); + + // Parse + let parsed: Vec = serde_json::from_slice(&decompressed).unwrap(); + + assert_eq!(parsed, original); + } }