diff --git a/Cargo.lock b/Cargo.lock index 16f4a91a39..1950bb098d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13042,6 +13042,7 @@ name = "walrus-e2e-tests" version = "1.41.0" dependencies = [ "anyhow", + "bytes", "futures", "hex", "indicatif 0.17.11", diff --git a/crates/walrus-e2e-tests/Cargo.toml b/crates/walrus-e2e-tests/Cargo.toml index a626e0034d..27e1d86c55 100644 --- a/crates/walrus-e2e-tests/Cargo.toml +++ b/crates/walrus-e2e-tests/Cargo.toml @@ -11,6 +11,7 @@ anyhow.workspace = true prometheus.workspace = true rand.workspace = true # explicitly import reqwest in test to disable its system proxy cache. It causes indeterminism in simtest. +bytes.workspace = true futures.workspace = true hex.workspace = true indicatif.workspace = true diff --git a/crates/walrus-e2e-tests/tests/test_client.rs b/crates/walrus-e2e-tests/tests/test_client.rs index 4ee840bbd3..37370f3fb6 100644 --- a/crates/walrus-e2e-tests/tests/test_client.rs +++ b/crates/walrus-e2e-tests/tests/test_client.rs @@ -21,12 +21,13 @@ use std::{ time::Duration, }; +use bytes::Bytes; use rand::{Rng, random, seq::SliceRandom, thread_rng}; use reqwest::Url; #[cfg(msim)] use sui_macros::{clear_fail_point, register_fail_point_if}; use sui_types::base_types::{SUI_ADDRESS_LENGTH, SuiAddress}; -use tokio::sync::Mutex; +use tokio::{sync::Mutex, time::Instant}; use tokio_stream::StreamExt; use walrus_core::{ BlobId, @@ -58,6 +59,7 @@ use walrus_sdk::{ client_types::WalrusStoreBlob, quilt_client::QuiltClientConfig, responses::{BlobStoreResult, QuiltStoreResult}, + streaming::start_streaming_blob, upload_relay_client::UploadRelayClient, }, config::ClientConfig, @@ -2990,3 +2992,79 @@ async fn test_byte_range_read_size_too_large() -> TestResult { } Ok(()) } + +/// Tests that streaming a blob returns the correct data. 
+#[ignore = "ignore E2E tests by default"] +#[walrus_simtest] +async fn test_streaming_blob() -> TestResult { + walrus_test_utils::init_tracing(); + + // Setup test cluster + let (_sui_cluster_handle, _cluster, client, _) = + test_cluster::E2eTestSetupBuilder::new().build().await?; + + // Generate and store a test blob (~100KB to span multiple slivers) + let blob_size = 100_000; + let original_data = walrus_test_utils::random_data(blob_size); + + let store_args = StoreArgs::default_with_epochs(5).with_encoding_type(DEFAULT_ENCODING); + + let store_results = client + .inner + .reserve_and_store_blobs(vec![original_data.clone()], &store_args) + .await?; + + let blob_id = store_results + .into_iter() + .next() + .expect("should have one blob store result") + .blob_id() + .expect("blob ID should be present"); + + // Create a read-only client for streaming (SuiReadClient implements Clone) + let sui_read_client = client.inner.sui_client().read_client().clone(); + let config = client.inner.config().clone(); + let streaming_config = config.streaming_config.clone(); + + let read_client = + WalrusNodeClient::new_read_client_with_refresher(config, sui_read_client).await?; + let arc_client = Arc::new(read_client); + + // Call start_streaming_blob directly + let start = Instant::now(); + let (stream, returned_size) = + start_streaming_blob(arc_client, streaming_config, blob_id).await?; + + // Verify returned blob size matches + assert_eq!( + returned_size, blob_size as u64, + "returned blob size should match" + ); + + // Collect stream chunks + let collected: Vec = stream + .map(|result| result.expect("stream chunk should succeed")) + .collect() + .await; + + tracing::info!( + "Collected {} chunks in {:?}", + collected.len(), + start.elapsed() + ); + // Concatenate all chunks + let streamed_data: Vec = collected.into_iter().flat_map(|b| b.to_vec()).collect(); + + // Verify data matches original + assert_eq!( + streamed_data.len(), + original_data.len(), + "streamed data length 
should match original" + ); + assert_eq!( + streamed_data, original_data, + "streamed data should match original blob" + ); + + Ok(()) +} diff --git a/crates/walrus-sdk/client_config_example.yaml b/crates/walrus-sdk/client_config_example.yaml index 1c630082c1..9fba950a3f 100644 --- a/crates/walrus-sdk/client_config_example.yaml +++ b/crates/walrus-sdk/client_config_example.yaml @@ -66,3 +66,7 @@ quilt_client_config: byte_range_read_client_config: max_retrieve_slivers_attempts: 2 timeout_secs: 10 +streaming_config: + max_sliver_retry_attempts: 5 + sliver_timeout_secs: 30 + prefetch_count: 4 diff --git a/crates/walrus-sdk/src/client.rs b/crates/walrus-sdk/src/client.rs index 4546116d5a..4c3d237320 100644 --- a/crates/walrus-sdk/src/client.rs +++ b/crates/walrus-sdk/src/client.rs @@ -113,6 +113,7 @@ use crate::{ pub mod byte_range_read_client; pub mod client_types; pub mod communication; +pub mod streaming; pub use communication::NodeCommunicationFactory; pub mod metrics; pub mod quilt_client; diff --git a/crates/walrus-sdk/src/client/streaming.rs b/crates/walrus-sdk/src/client/streaming.rs new file mode 100644 index 0000000000..3a9b1452c9 --- /dev/null +++ b/crates/walrus-sdk/src/client/streaming.rs @@ -0,0 +1,561 @@ +// Copyright (c) Walrus Foundation +// SPDX-License-Identifier: Apache-2.0 + +//! Client for streaming blob data sliver-by-sliver. 
+ +use std::{ + collections::BTreeMap, + num::{NonZeroU16, NonZeroU32}, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, +}; + +use bytes::Bytes; +use futures::{ + Stream, + StreamExt as _, + future::{AbortHandle, AbortRegistration, Abortable}, +}; +use serde::{Deserialize, Serialize}; +use tokio::sync::{Mutex, Notify}; +use walrus_core::{ + BlobId, + EncodingType, + Epoch, + SliverIndex, + encoding::{EncodingFactory, Primary, SliverData}, + metadata::{BlobMetadataApi as _, VerifiedBlobMetadataWithId}, +}; +use walrus_sui::client::ReadClient; + +use crate::{ + client::WalrusNodeClient, + error::{ClientError, ClientErrorKind, ClientResult}, +}; + +/// Configuration for streaming blob reads. +#[serde_with::serde_as] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct StreamingConfig { + /// Maximum number of retry attempts per sliver before aborting. + pub max_sliver_retry_attempts: usize, + /// Timeout duration for individual sliver retrieval. + #[serde_as(as = "serde_with::DurationSeconds")] + #[serde(rename = "sliver_timeout_secs")] + pub sliver_timeout: Duration, + /// Number of slivers to prefetch ahead of the current streaming position. + pub prefetch_count: u16, +} + +impl StreamingConfig { + /// Creates a new StreamingConfig. + pub fn new( + max_sliver_retry_attempts: usize, + sliver_timeout: Duration, + prefetch_count: u16, + ) -> Self { + Self { + max_sliver_retry_attempts, + sliver_timeout, + prefetch_count, + } + } +} + +impl Default for StreamingConfig { + fn default() -> Self { + Self { + max_sliver_retry_attempts: 5, + sliver_timeout: Duration::from_secs(30), + prefetch_count: 4, + } + } +} + +/// Represents the state of a single sliver in the streaming pipeline. +/// +/// Note: Slivers not yet requested are simply absent from the map rather than +/// having an explicit "pending" state. +#[derive(Debug)] +enum SliverFetchState { + /// Sliver retrieval is in progress. 
Contains the abort handle to cancel if needed. + InFlight(AbortHandle), + /// Sliver was successfully retrieved. + Ready(SliverData), + /// Sliver retrieval failed after all retries. + Failed(ClientError), +} + +/// Manages the prefetch buffer and ordering of slivers for streaming. +struct SliverPrefetchBuffer { + /// Map of sliver index to its fetch state. + slivers: BTreeMap, + /// The next sliver index to stream to the client. + next_to_stream: u16, + /// Total number of slivers in the blob. + total_slivers: NonZeroU16, + /// Number of slivers to keep in-flight ahead of current position. + prefetch_count: u16, +} + +impl SliverPrefetchBuffer { + /// Creates a new SliverPrefetchBuffer. + fn new(total_slivers: NonZeroU16, prefetch_count: u16) -> Self { + Self { + slivers: BTreeMap::new(), + next_to_stream: 0, + total_slivers, + prefetch_count, + } + } + + /// Returns sliver indices that should be fetched next. + /// + /// Returns indices for slivers that are not yet in the map (i.e., not yet requested). + fn get_indices_to_fetch(&self) -> Vec { + let mut indices = Vec::new(); + let end = std::cmp::min( + self.next_to_stream.saturating_add(self.prefetch_count), + self.total_slivers.get(), + ); + + for idx in self.next_to_stream..end { + if !self.slivers.contains_key(&idx) { + indices.push(SliverIndex::new(idx)); + } + } + indices + } + + /// Marks a sliver as in-flight with its abort handle. + fn set_in_flight(&mut self, index: u16, abort_handle: AbortHandle) { + self.slivers + .insert(index, SliverFetchState::InFlight(abort_handle)); + } + + /// Updates a sliver to ready state with the fetched data. + fn set_ready(&mut self, index: u16, data: SliverData) { + self.slivers.insert(index, SliverFetchState::Ready(data)); + } + + /// Updates a sliver to failed state. + fn set_failed(&mut self, index: u16, error: ClientError) { + self.slivers.insert(index, SliverFetchState::Failed(error)); + } + + /// Attempts to take the next sliver to stream (if ready). 
+ /// Returns None if streaming is complete or next sliver is still in-flight. + /// Returns Some(Ok(sliver)) and advances the next_to_stream if the next sliver is ready. + /// Returns Some(Err(error)) if the next sliver failed. + #[must_use] + fn try_take_next(&mut self) -> Option, ClientError>> { + if self.next_to_stream >= self.total_slivers.get() { + return None; + } + + match self.slivers.remove(&self.next_to_stream) { + Some(SliverFetchState::Ready(data)) => { + self.next_to_stream += 1; + Some(Ok(data)) + } + Some(SliverFetchState::Failed(e)) => Some(Err(e)), + Some(in_flight @ SliverFetchState::InFlight(_)) => { + // Still fetching, put it back + self.slivers.insert(self.next_to_stream, in_flight); + None + } + // Not yet requested - shouldn't happen if prefetching is working correctly. + None => None, + } + } + + /// Returns true if all slivers have been streamed. + fn is_complete(&self) -> bool { + self.next_to_stream >= self.total_slivers.get() + } +} + +impl Drop for SliverPrefetchBuffer { + fn drop(&mut self) { + // Abort all in-flight tasks to avoid unnecessary work. + for (_, state) in self.slivers.iter() { + if let SliverFetchState::InFlight(abort_handle) = state { + abort_handle.abort(); + } + } + } +} + +#[derive(Clone)] +struct StreamingState { + /// Prefetch buffer managing sliver states. + prefetch_buffer: Arc>, + /// Notifier for when slivers become ready. + notify: Arc, + /// Whether the stream has been aborted due to an error. + aborted: Arc, + /// Blob metadata. + metadata: Arc, + /// Certified epoch. + certified_epoch: Epoch, + /// Size of each primary sliver in bytes. + primary_sliver_size: NonZeroU32, + /// Total blob size in bytes. + blob_size: u64, + /// Total slivers in the blob. 
+ total_slivers: NonZeroU16, +} + +impl StreamingState { + fn is_aborted(&self) -> bool { + self.aborted.load(Ordering::Relaxed) + } + + async fn is_complete(&self) -> bool { + self.prefetch_buffer.lock().await.is_complete() + } + + fn set_aborted(&self) { + self.aborted.store(true, Ordering::Relaxed) + } + + async fn wait_for_notify(&self) { + self.notify.notified().await + } + + async fn poll_for_next_sliver(&self) -> Option, ClientError>> { + // Check if next sliver is ready + let mut prefetch_buffer = self.prefetch_buffer.lock().await; + prefetch_buffer.try_take_next().map(|result| { + result.map(|sliver| { + extract_sliver_data( + self.primary_sliver_size, + self.blob_size, + self.total_slivers, + sliver, + ) + }) + }) + } +} + +/// Creates a stream that yields blob data chunks in order, sliver by sliver. +/// +/// This method retrieves blob data progressively, prefetching slivers ahead +/// of the current streaming position to minimize latency. Each chunk yielded +/// corresponds to one sliver's worth of data (except the last sliver which +/// may be trimmed to the actual blob size). +/// +/// Returns an error for various reasons, including: +/// - The blob is blocked +/// - The blob doesn't exist +/// - Metadata retrieval fails +/// - The encoding type is not supported +/// +/// Returns the stream and the total blob size in bytes (for progress tracking). 
+pub async fn start_streaming_blob( + client: Arc>, + config: StreamingConfig, + blob_id: BlobId, +) -> ClientResult<(impl Stream> + Send, u64)> { + tracing::debug!(%blob_id, "starting to stream blob"); + + let (certified_epoch, _) = client + .get_blob_status_and_certified_epoch(&blob_id, None) + .await?; + + let metadata = client.retrieve_metadata(certified_epoch, &blob_id).await?; + + if metadata.metadata().encoding_type() != EncodingType::RS2 { + return Err(ClientError::from(ClientErrorKind::Other( + format!( + "streaming read client only supports RS2 encoding, got {}", + metadata.metadata().encoding_type() + ) + .into(), + ))); + } + + let blob_size = metadata.metadata().unencoded_length(); + + // Handle zero-size blobs early - return empty stream + if blob_size == 0 { + tracing::debug!(%blob_id, "zero-size blob, returning empty stream"); + return Ok((futures::stream::empty().left_stream(), 0)); + } + + let (primary_sliver_size, primary_sliver_count) = + get_primary_sliver_size_and_count(&client, blob_size, &metadata)?; + + tracing::debug!( + %blob_id, + blob_size, + primary_sliver_size, + primary_sliver_count, + "blob metadata retrieved for streaming" + ); + + // Create the streaming state + let state = StreamingState { + prefetch_buffer: Arc::new(Mutex::new(SliverPrefetchBuffer::new( + primary_sliver_count, + config.prefetch_count, + ))), + notify: Arc::new(Notify::new()), + aborted: Arc::new(AtomicBool::from(false)), + metadata: Arc::new(metadata), + certified_epoch, + primary_sliver_size, + blob_size, + total_slivers: primary_sliver_count, + }; + + Ok(( + create_sliver_stream(client, config, state).right_stream(), + blob_size, + )) +} + +/// Gets the size and count of primary slivers for the given blob size and metadata. 
+fn get_primary_sliver_size_and_count( + client: &WalrusNodeClient, + blob_size: u64, + metadata: &VerifiedBlobMetadataWithId, +) -> ClientResult<(NonZeroU32, NonZeroU16)> { + let encoding_config = client + .encoding_config() + .get_for_type(metadata.metadata().encoding_type()); + let primary_sliver_size = encoding_config + .sliver_size_for_blob::(blob_size) + .map_err(|_| { + ClientError::from(ClientErrorKind::Other( + "blob too large to determine sliver size".into(), + )) + })?; + let primary_sliver_count = encoding_config.n_systematic_slivers::(); + Ok((primary_sliver_size, primary_sliver_count)) +} + +/// Creates the async stream that prefetches and yields sliver data. +fn create_sliver_stream( + client: Arc>, + config: StreamingConfig, + state: StreamingState, +) -> impl Stream> + Send { + futures::stream::unfold( + (state, client, config), + |(state, client, config)| async move { + loop { + if state.is_aborted() || state.is_complete().await { + // Checking is_aborted here is just to ensure that this stream ends if being + // polled after an abort. + return None; + } + + spawn_prefetch_tasks(client.clone(), config.clone(), state.clone()).await; + + match state.poll_for_next_sliver().await { + Some(Ok(data)) => { + return Some((Ok(Bytes::from(data)), (state, client, config))); + } + Some(Err(e)) => { + state.set_aborted(); + return Some((Err(e), (state, client, config))); + } + None => state.wait_for_notify().await, + } + } + }, + ) +} + +/// Retrieves a single sliver. 
+async fn retrieve_single_sliver_with_retry( + client: &WalrusNodeClient, + config: &StreamingConfig, + metadata: Arc, + sliver_index: SliverIndex, + certified_epoch: Epoch, +) -> ClientResult> { + let slivers: Vec> = client + .retrieve_slivers_retry_committees::( + metadata.as_ref(), + &[sliver_index], + certified_epoch, + config.max_sliver_retry_attempts, // Full retry budget for this sliver + config.sliver_timeout, + ) + .await?; + slivers.into_iter().next().ok_or_else(|| { + ClientError::from(ClientErrorKind::Other( + format!( + "unexpected empty sliver result for sliver {}", + sliver_index.get() + ) + .into(), + )) + }) +} + +/// Spawns prefetch tasks for slivers that need to be fetched. +async fn spawn_prefetch_tasks( + client: Arc>, + config: StreamingConfig, + state: StreamingState, +) { + let tasks_to_spawn: Vec<(SliverIndex, AbortRegistration)> = { + let mut prefetch_buffer = state.prefetch_buffer.lock().await; + let indices = prefetch_buffer.get_indices_to_fetch(); + + let tasks: Vec<_> = indices + .into_iter() + .map(|index| { + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + prefetch_buffer.set_in_flight(index.get(), abort_handle); + (index, abort_registration) + }) + .collect(); + + tasks + }; + + // Spawn all tasks outside the lock. + for (index, abort_registration) in tasks_to_spawn { + let client = client.clone(); + let config = config.clone(); + let state = state.clone(); + + tokio::spawn(Abortable::new( + async move { + let result = retrieve_single_sliver_with_retry( + &client, + &config, + state.metadata.clone(), + index, + state.certified_epoch, + ) + .await; + + { + let mut prefetch_buffer = state.prefetch_buffer.lock().await; + match result { + Ok(sliver) => prefetch_buffer.set_ready(index.get(), sliver), + Err(e) => prefetch_buffer.set_failed(index.get(), e), + } + } + state.notify.clone().notify_one(); + }, + abort_registration, + )); + } +} + +/// Extracts the data from a sliver, handling the last sliver specially. 
+fn extract_sliver_data( + primary_sliver_size: NonZeroU32, + blob_size: u64, + total_slivers: NonZeroU16, + sliver: SliverData, +) -> Vec { + let mut sliver_data = sliver.symbols.into_vec(); + let sliver_index = u64::from(sliver.index.get()); + let total_slivers = u64::from(u16::from(total_slivers)); + + let is_last_sliver = total_slivers > 0 && sliver_index == total_slivers.saturating_sub(1); + + if is_last_sliver && blob_size > 0 { + // Calculate expected data in last sliver + let full_slivers_size = (total_slivers - 1) * u64::from(primary_sliver_size.get()); + let last_sliver_data_size = + usize::try_from(blob_size - full_slivers_size).expect("should fit in u64"); + + // Trim padding from last sliver + if last_sliver_data_size < sliver_data.len() { + sliver_data.truncate(last_sliver_data_size) + } + } + sliver_data +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sliver_prefetch_buffer_basic() { + let mut buffer = SliverPrefetchBuffer::new(NonZeroU16::try_from(10).unwrap(), 4); + + // Initially should want to fetch first 4 slivers + let indices = buffer.get_indices_to_fetch(); + assert_eq!(indices.len(), 4); + assert_eq!(indices[0].get(), 0); + assert_eq!(indices[3].get(), 3); + + // Mark some as in-flight + let (abort_handle, _) = AbortHandle::new_pair(); + buffer.set_in_flight(0, abort_handle); + + // Should not return in-flight slivers + let indices = buffer.get_indices_to_fetch(); + assert_eq!(indices.len(), 3); + assert_eq!(indices[0].get(), 1); + } + + #[test] + fn test_sliver_prefetch_buffer_advance() { + let mut buffer = SliverPrefetchBuffer::new(NonZeroU16::try_from(3).unwrap(), 2); + + // Set up first sliver as ready + buffer.set_ready( + 0, + SliverData::new( + vec![1, 2, 3], + std::num::NonZeroU16::new(1).unwrap(), + SliverIndex::new(0), + ), + ); + + // Take it + let result = buffer.try_take_next(); + assert!(result.is_some()); + assert!(result.unwrap().is_ok()); + + // Advance + assert_eq!(buffer.next_to_stream, 1); + 
assert!(!buffer.is_complete()); + + // Advance through remaining + buffer.set_ready( + 1, + SliverData::new( + vec![4, 5, 6], + std::num::NonZeroU16::new(1).unwrap(), + SliverIndex::new(1), + ), + ); + let _ = buffer.try_take_next(); + + buffer.set_ready( + 2, + SliverData::new( + vec![7, 8, 9], + std::num::NonZeroU16::new(1).unwrap(), + SliverIndex::new(2), + ), + ); + let _ = buffer.try_take_next(); + + assert!(buffer.is_complete()); + } + + #[test] + fn test_config_defaults() { + let config = StreamingConfig::default(); + assert_eq!(config.max_sliver_retry_attempts, 5); + assert_eq!(config.sliver_timeout, Duration::from_secs(30)); + assert_eq!(config.prefetch_count, 4); + } +} diff --git a/crates/walrus-sdk/src/config.rs b/crates/walrus-sdk/src/config.rs index d80e50682e..a56ace03e6 100644 --- a/crates/walrus-sdk/src/config.rs +++ b/crates/walrus-sdk/src/config.rs @@ -38,6 +38,7 @@ use crate::client::{ byte_range_read_client::ByteRangeReadClientConfig, quilt_client::QuiltClientConfig, refresh::{CommitteesRefresher, CommitteesRefresherHandle}, + streaming::StreamingConfig, }; mod committees_refresh_config; @@ -124,6 +125,9 @@ pub struct ClientConfig { /// The configuration of the ByteRangeReadClient. #[serde(default)] pub byte_range_read_client_config: ByteRangeReadClientConfig, + /// The configuration of the StreamingReadClient. 
+ #[serde(default)] + pub streaming_config: StreamingConfig, } impl ClientConfig { @@ -139,6 +143,7 @@ impl ClientConfig { refresh_config: Default::default(), quilt_client_config: Default::default(), byte_range_read_client_config: Default::default(), + streaming_config: Default::default(), } } @@ -354,6 +359,7 @@ mod tests { refresh_config: Default::default(), quilt_client_config: Default::default(), byte_range_read_client_config: Default::default(), + streaming_config: Default::default(), }; walrus_test_utils::overwrite_file_and_fail_if_not_equal( diff --git a/crates/walrus-service/src/client/daemon.rs b/crates/walrus-service/src/client/daemon.rs index 8019dcf945..5a03299aae 100644 --- a/crates/walrus-service/src/client/daemon.rs +++ b/crates/walrus-service/src/client/daemon.rs @@ -3,12 +3,12 @@ //! A client daemon who serves a set of simple HTTP endpoints to store, encode, or read blobs. -use std::{collections::HashSet, fmt::Debug, net::SocketAddr, str::FromStr, sync::Arc}; +use std::{collections::HashSet, fmt::Debug, net::SocketAddr, pin::Pin, str::FromStr, sync::Arc}; use axum::{ BoxError, Router, - body::HttpBody, + body::{Bytes, HttpBody}, error_handling::HandleErrorLayer, extract::{DefaultBodyLimit, Query, Request, State}, http::HeaderName, @@ -20,6 +20,7 @@ use axum_extra::{ TypedHeader, headers::{Authorization, authorization::Bearer}, }; +use futures::Stream; use openapi::{AggregatorApiDoc, DaemonApiDoc, PublisherApiDoc}; use reqwest::StatusCode; use routes::{ @@ -28,6 +29,7 @@ use routes::{ BLOB_GET_ENDPOINT, BLOB_OBJECT_GET_ENDPOINT, BLOB_PUT_ENDPOINT, + BLOB_STREAM_ENDPOINT, LIST_PATCHES_IN_QUILT_ENDPOINT, QUILT_PATCH_BY_ID_GET_ENDPOINT, QUILT_PATCH_BY_IDENTIFIER_GET_ENDPOINT, @@ -65,6 +67,7 @@ use walrus_sdk::{ WalrusNodeClient, byte_range_read_client::ReadByteRangeResult, responses::{BlobStoreResult, QuiltStoreResult}, + streaming::start_streaming_blob, }, error::{ClientError, ClientResult}, store_optimizations::StoreOptimizations, @@ -90,6 +93,9 @@ 
pub(crate) use cache::{CacheConfig, CacheHandle}; mod openapi; mod routes; +/// Type alias for a boxed stream of blob data chunks. +pub type BlobStream = Pin> + Send>>; + pub trait WalrusReadClient { /// Reads a blob from Walrus. fn read_blob( @@ -133,6 +139,18 @@ pub trait WalrusReadClient { &self, _quilt_id: &BlobId, ) -> impl std::future::Future>> + Send; + + /// Streams a blob sliver-by-sliver. + /// + /// Returns a stream that yields blob data chunks in order, with prefetching + /// and aggressive retry logic for improved performance on large blobs. + /// + /// Takes `Arc` to allow spawning background tasks for prefetching. + /// Returns the stream and the total blob size in bytes (for progress tracking). + fn stream_blob( + self: Arc, + blob_id: &BlobId, + ) -> impl std::future::Future> + Send; } /// Trait representing a client that can write blobs to Walrus. @@ -170,7 +188,7 @@ pub trait WalrusWriteClient: WalrusReadClient { fn default_post_store_action(&self) -> PostStoreAction; } -impl WalrusReadClient for WalrusNodeClient { +impl WalrusReadClient for WalrusNodeClient { async fn read_blob( &self, blob_id: &BlobId, @@ -250,6 +268,13 @@ impl WalrusReadClient for WalrusNodeClient { Ok(patches) } + + async fn stream_blob(self: Arc, blob_id: &BlobId) -> ClientResult<(BlobStream, u64)> { + let config = self.config().streaming_config.clone(); + let (stream, blob_size) = start_streaming_blob(self, config, *blob_id).await?; + + Ok((Box::pin(stream), blob_size)) + } } impl WalrusWriteClient for WalrusNodeClient { @@ -435,6 +460,10 @@ impl ClientDaemon { BLOB_BYTE_RANGE_GET_ENDPOINT, get(routes::get_blob_byte_range).route_layer(aggregator_layers.clone()), ) + .route( + BLOB_STREAM_ENDPOINT, + get(routes::stream_blob).route_layer(aggregator_layers.clone()), + ) .route( BLOB_CONCAT_ENDPOINT, get(routes::get_blobs_concat) @@ -810,6 +839,13 @@ mod tests { ) -> ClientResult> { unimplemented!("not needed for rate limit tests") } + + async fn stream_blob( + self: Arc, + 
_blob_id: &BlobId, + ) -> ClientResult<(BlobStream, u64)> { + unimplemented!("not needed for rate limit tests") + } } #[tokio::test] diff --git a/crates/walrus-service/src/client/daemon/routes.rs b/crates/walrus-service/src/client/daemon/routes.rs index 60eb946919..e99284e26a 100644 --- a/crates/walrus-service/src/client/daemon/routes.rs +++ b/crates/walrus-service/src/client/daemon/routes.rs @@ -29,7 +29,15 @@ use http_range_header::{EndPosition, StartPosition}; use jsonwebtoken::{DecodingKey, Validation}; use reqwest::{ Method, - header::{ACCEPT, CACHE_CONTROL, CONTENT_TYPE, ETAG, RANGE, X_CONTENT_TYPE_OPTIONS}, + header::{ + ACCEPT, + CACHE_CONTROL, + CONTENT_LENGTH, + CONTENT_TYPE, + ETAG, + RANGE, + X_CONTENT_TYPE_OPTIONS, + }, }; use serde::{Deserialize, Serialize}; use serde_with::{DisplayFromStr, serde_as}; @@ -97,6 +105,8 @@ pub const QUILT_PATCH_BY_IDENTIFIER_GET_ENDPOINT: &str = pub const LIST_PATCHES_IN_QUILT_ENDPOINT: &str = "/v1/quilts/{quilt_id}/patches"; /// The path to read a byte range from a blob. pub const BLOB_BYTE_RANGE_GET_ENDPOINT: &str = "/v1/blobs/{blob_id}/byte-range"; +/// The path to stream a blob sliver-by-sliver. +pub const BLOB_STREAM_ENDPOINT: &str = "/v1alpha/blobs/{blob_id}/stream"; /// Custom header for quilt patch identifier. const X_QUILT_PATCH_IDENTIFIER: &str = "X-Quilt-Patch-Identifier"; @@ -359,6 +369,7 @@ pub(super) async fn get_blob( request_method, &request_headers, &blob_id.to_string(), + None, headers, ); response @@ -368,6 +379,7 @@ fn populate_response_headers_from_request( request_method: Method, request_headers: &HeaderMap, etag: &str, + blob_size: Option, headers: &mut HeaderMap, ) { // Prevent the browser from trying to guess the MIME type to avoid dangerous inferences. @@ -388,6 +400,12 @@ fn populate_response_headers_from_request( HeaderValue::from_str(etag) .expect("the blob ID string only contains visible ASCII characters"), ); + + // Certain codepaths may provide the content-length (e.g., streaming reads). 
+ if let Some(content_length) = blob_size.map(HeaderValue::from) { + headers.insert(CONTENT_LENGTH, content_length); + } + // Mirror the content type in various ways. if let Some(accept) = request_headers.get(ACCEPT) && !accept.as_bytes().contains(&b'*') @@ -647,6 +665,7 @@ pub(super) async fn get_blob_byte_range( request_method, &request_headers, &etag, + None, headers, ); response @@ -661,6 +680,107 @@ pub(super) async fn get_blob_byte_range( } } +/// Errors that can occur when streaming a blob. +#[derive(Debug, thiserror::Error, RestApiError)] +#[rest_api_error(domain = ERROR_DOMAIN)] +pub(crate) enum StreamBlobError { + /// The requested blob has not yet been stored on Walrus. + #[error("the requested blob ID does not exist on Walrus")] + #[rest_api_error(reason = "BLOB_NOT_FOUND", status = ApiStatusCode::NotFound)] + BlobNotFound, + + /// The blob cannot be returned as it has been blocked. + #[error("the requested blob is blocked")] + #[rest_api_error(reason = "FORBIDDEN_BLOB", status = ApiStatusCode::UnavailableForLegalReasons)] + Blocked, + + /// Failed to retrieve one or more slivers after retries. + #[error("failed to retrieve sliver after retries: {message}")] + #[rest_api_error(reason = "SLIVER_RETRIEVAL_FAILED", status = ApiStatusCode::Internal)] + SliverRetrievalFailed { message: String }, + + /// The blob size exceeds the maximum allowed size. + #[error("the blob size exceeds the maximum allowed size: {0}")] + #[rest_api_error(reason = "BLOB_TOO_LARGE", status = ApiStatusCode::SizeExceeded)] + BlobTooLarge(u64), +} + +impl From for StreamBlobError { + fn from(error: ClientError) -> Self { + match error.kind() { + ClientErrorKind::BlobIdDoesNotExist => Self::BlobNotFound, + ClientErrorKind::BlobIdBlocked(_) => Self::Blocked, + ClientErrorKind::BlobTooLarge(max_blob_size) => Self::BlobTooLarge(*max_blob_size), + _ => Self::SliverRetrievalFailed { + message: error.to_string(), + }, + } + } +} + +/// Stream a Walrus blob sliver-by-sliver. 
+/// +/// Reconstructs and streams the blob identified by the provided blob ID from Walrus. +/// Data is streamed progressively as slivers are retrieved, reducing time-to-first-byte +/// for large blobs. +/// +/// This endpoint uses aggressive retry logic with longer timeouts for each sliver, +/// and prefetches slivers ahead of the current streaming position to minimize wait time. +/// +/// If a sliver cannot be retrieved after multiple retries, the stream will abort with an error. +/// Clients should be prepared to handle partial data in case of failures. +#[tracing::instrument(level = Level::ERROR, skip_all, fields(%blob_id))] +#[utoipa::path( + get, + path = BLOB_STREAM_ENDPOINT, + params(("blob_id" = BlobId,)), + responses( + (status = 200, description = "Blob streaming started successfully", body = [u8]), + StreamBlobError, + ), +)] +pub(super) async fn stream_blob( + request_method: Method, + request_headers: HeaderMap, + State(client): State>, + Path(BlobIdString(blob_id)): Path, +) -> Response { + tracing::debug!("starting to stream blob"); + + match client.stream_blob(&blob_id).await { + Ok((stream, blob_size)) => { + use futures::StreamExt; + // Wrap stream to convert ClientError to axum-compatible errors + let byte_stream = StreamExt::map(stream, |result: Result| { + result.map_err(|e: ClientError| { + tracing::error!(error = ?e, "error during blob streaming"); + std::io::Error::other(e.to_string()) + }) + }); + + let mut response = (StatusCode::OK, Body::from_stream(byte_stream)).into_response(); + let headers = response.headers_mut(); + populate_response_headers_from_request( + request_method, + &request_headers, + &blob_id.to_string(), + Some(blob_size), + headers, + ); + // Add streaming-specific headers + headers.insert( + HeaderName::from_static("x-walrus-streaming"), + HeaderValue::from_static("true"), + ); + response + } + Err(error) => { + tracing::debug!(?error, "failed to start blob stream"); + StreamBlobError::from(error).to_response() + } + 
} +} + #[derive(Debug, thiserror::Error, RestApiError)] #[rest_api_error(domain = ERROR_DOMAIN)] pub(crate) enum GetBlobError { @@ -836,7 +956,7 @@ async fn concat_blobs_impl( let mut response = (StatusCode::OK, Body::from_stream(stream)).into_response(); let headers = response.headers_mut(); - populate_response_headers_from_request(request_method, &request_headers, &etag, headers); + populate_response_headers_from_request(request_method, &request_headers, &etag, None, headers); if let Some(attribute) = first_attribute { populate_response_headers_from_attributes( @@ -1198,6 +1318,7 @@ fn build_quilt_patch_response( request_method, request_headers, etag, + None, response.headers_mut(), ); populate_response_headers_from_attributes( diff --git a/crates/walrus-service/src/client/multiplexer.rs b/crates/walrus-service/src/client/multiplexer.rs index 468379a33d..f2ab8a8e54 100644 --- a/crates/walrus-service/src/client/multiplexer.rs +++ b/crates/walrus-service/src/client/multiplexer.rs @@ -32,7 +32,7 @@ use walrus_sdk::{ responses::{BlobStoreResult, QuiltStoreResult}, }, config::ClientConfig, - error::ClientResult, + error::{ClientError, ClientErrorKind, ClientResult}, store_optimizations::StoreOptimizations, }; use walrus_sui::{ @@ -214,6 +214,25 @@ impl WalrusReadClient for ClientMultiplexer { async fn list_patches_in_quilt(&self, quilt_id: &BlobId) -> ClientResult> { self.read_client.list_patches_in_quilt(quilt_id).await } + + /// Streaming is not supported through the ClientMultiplexer. + /// + /// The multiplexer manages multiple write clients for parallelism, but the read client + /// is stored directly (not in an Arc) which is incompatible with the streaming API that + /// requires `Arc` for background prefetch tasks. 
+ /// + /// For streaming blob downloads, use the aggregator endpoint directly: + /// `GET /v1alpha/blobs/{blob_id}/stream` + async fn stream_blob( + self: Arc, + _blob_id: &BlobId, + ) -> ClientResult<(super::daemon::BlobStream, u64)> { + Err(ClientError::from(ClientErrorKind::Other( + "streaming not supported through ClientMultiplexer; \ + use the aggregator /v1alpha/blobs/{blob_id}/stream endpoint instead" + .into(), + ))) + } } impl WalrusWriteClient for ClientMultiplexer { diff --git a/crates/walrus-service/src/test_utils.rs b/crates/walrus-service/src/test_utils.rs index 38920b3bd5..256e9ecdc2 100644 --- a/crates/walrus-service/src/test_utils.rs +++ b/crates/walrus-service/src/test_utils.rs @@ -2998,6 +2998,7 @@ pub mod test_cluster { refresh_config: Default::default(), quilt_client_config: Default::default(), byte_range_read_client_config: Default::default(), + streaming_config: Default::default(), }; let client = admin_contract_client diff --git a/crates/walrus-service/src/testbed.rs b/crates/walrus-service/src/testbed.rs index cfd44e0f7d..452ad52e29 100644 --- a/crates/walrus-service/src/testbed.rs +++ b/crates/walrus-service/src/testbed.rs @@ -529,6 +529,7 @@ pub async fn create_client_config( refresh_config: Default::default(), quilt_client_config: Default::default(), byte_range_read_client_config: Default::default(), + streaming_config: Default::default(), }; Ok(client_config)