diff --git a/Cargo.lock b/Cargo.lock index e4dbf9452..10ed5adf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11546,9 +11546,14 @@ dependencies = [ "anyhow", "bytes 1.8.0", "ed25519-dalek 2.1.1", + "futures", + "futures-core", + "futures-util", "http-body-util", "hyper-util", "movement-da-sequencer-proto", + "tokio", + "tokio-stream", "tonic 0.12.3", "tonic-web", "tower 0.5.1", diff --git a/protocol-units/da-sequencer/client/Cargo.toml b/protocol-units/da-sequencer/client/Cargo.toml index 4ffe8f2a0..217f745cc 100644 --- a/protocol-units/da-sequencer/client/Cargo.toml +++ b/protocol-units/da-sequencer/client/Cargo.toml @@ -11,7 +11,10 @@ rust-version.workspace = true [dependencies] ed25519-dalek = { workspace = true } -movement-da-sequencer-proto = { workspace = true, features = ["client"] } +futures-util = "0.3" +futures-core = "0.3" +futures = { workspace = true } +movement-da-sequencer-proto = { workspace = true, features = ["client", "server"] } tonic = { workspace = true, features = ["tls", "tls-webpki-roots"]} tonic-web = { workspace = true } hyper-util = { workspace = true } @@ -20,6 +23,8 @@ http-body-util = { workspace = true } bytes = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } [lints] #workspace = true diff --git a/protocol-units/da-sequencer/client/src/lib.rs b/protocol-units/da-sequencer/client/src/lib.rs index a80bb03a7..0711820f8 100644 --- a/protocol-units/da-sequencer/client/src/lib.rs +++ b/protocol-units/da-sequencer/client/src/lib.rs @@ -1,8 +1,17 @@ -use ed25519_dalek::Signer; -use ed25519_dalek::{Signature, SigningKey}; +use anyhow::Result; +use ed25519_dalek::{Signature, Signer, SigningKey}; +use futures_core::Stream; +use futures_util::stream::unfold; use movement_da_sequencer_proto::da_sequencer_node_service_client::DaSequencerNodeServiceClient; +use movement_da_sequencer_proto::{ + BatchWriteRequest, BatchWriteResponse, StreamReadFromHeightRequest, + StreamReadFromHeightResponse, +}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::time::Duration; use tonic::transport::{Channel, ClientTlsConfig}; +use tonic::{Status, Streaming}; /// A wrapping MovementDaLightNodeClients over complex types. /// @@ -10,58 +19,41 @@ use tonic::transport::{Channel, ClientTlsConfig}; /// This simplifies client construction and usage. #[derive(Debug, Clone)] pub struct DaSequencerClient { - client: DaSequencerNodeServiceClient, + client: DaSequencerNodeServiceClient, + connection_string: String, + last_received_height: Arc>>, } impl DaSequencerClient { /// Creates an http2 connection to the Da Sequencer node service. - pub async fn try_connect(connection_string: &str) -> Result { + pub async fn try_connect(connection_string: &str) -> Result { for _ in 0..5 { - match DaSequencerClient::connect(connection_string).await { - Ok(client) => return Ok(DaSequencerClient { client }), + match Self::connect(connection_string).await { + Ok(client) => { + return Ok(Self { + client, + connection_string: connection_string.to_string(), + last_received_height: Arc::new(Mutex::new(None)), + }); + } Err(err) => { tracing::warn!( - "DA sequencer Http2 connection failed: {}. Retrying in 10s...", + "DA sequencer HTTP/2 connection failed: {}. Retrying in 2s...", err ); - std::thread::sleep(std::time::Duration::from_secs(10)); + tokio::time::sleep(Duration::from_secs(2)).await; } } } - return Err(anyhow::anyhow!( - "Error DA Sequencer Http2 connection failed more than 5 time aborting.", - )); - } - /// Stream reads from a given height. - pub async fn stream_read_from_height( - &mut self, - request: movement_da_sequencer_proto::StreamReadFromHeightRequest, - ) -> Result< - tonic::Streaming, - tonic::Status, - > { - let response = self.client.stream_read_from_height(request).await?; - Ok(response.into_inner()) - } - - /// Writes a batch of transactions to the light node - pub async fn batch_write( - &mut self, - request: movement_da_sequencer_proto::BatchWriteRequest, - ) -> Result { - let response = self.client.batch_write(request).await?; - Ok(response.into_inner()) + Err(anyhow::anyhow!("Connection failed more than 5 times")) } - /// Connects to a da sequencer node service using the given connection string. - async fn connect( - connection_string: &str, - ) -> Result, anyhow::Error> { - tracing::info!("Grpc client connect using :{connection_string}"); + /// Opens a raw tonic connection to the DA service. + async fn connect(connection_string: &str) -> Result> { + tracing::info!("Grpc client connect using: {}", connection_string); let endpoint = Channel::from_shared(connection_string.to_string())?; - // Dynamically configure TLS based on the scheme (http or https) let endpoint = if connection_string.starts_with("https://") { endpoint .tls_config(ClientTlsConfig::new().with_enabled_roots())? @@ -71,9 +63,108 @@ impl DaSequencerClient { }; let channel = endpoint.connect().await?; - let client = DaSequencerNodeServiceClient::new(channel); + Ok(DaSequencerNodeServiceClient::new(channel)) + } + + /// Reconnects the internal gRPC client. + async fn reconnect(&mut self) -> Result<()> { + tracing::info!("Reconnecting to {}", self.connection_string); + let client = Self::connect(&self.connection_string).await?; + self.client = client; + Ok(()) + } + + /// Streams blocks starting from a height, with reconnect and resume. + pub async fn stream_read_from_height( + &mut self, + start_request: StreamReadFromHeightRequest, + ) -> Result< + Pin> + Send>>, + Status, + > { + let height = { + let last = self.last_received_height.lock().unwrap(); + if let Some(last_h) = *last { + tracing::info!("Resuming stream from height: {}", last_h + 1); + last_h + 1 + } else { + tracing::info!("Starting stream from requested height: {}", start_request.height); + start_request.height + } + }; + + match self + .client + .stream_read_from_height(StreamReadFromHeightRequest { height }) + .await + { + Ok(response) => Ok(Self::wrap_stream_with_height_tracking( + response.into_inner(), + Arc::clone(&self.last_received_height), + )), + Err(e) => { + tracing::warn!("stream_read_from_height failed, trying reconnect: {e}"); + self.reconnect() + .await + .map_err(|e| Status::unavailable(format!("Reconnect failed: {e}")))?; + + let response = self + .client + .stream_read_from_height(StreamReadFromHeightRequest { height }) + .await?; + + Ok(Self::wrap_stream_with_height_tracking( + response.into_inner(), + Arc::clone(&self.last_received_height), + )) + } + } + } + + /// Wraps a stream to track and store the last received height. + fn wrap_stream_with_height_tracking( + stream: Streaming, + last_received_height: Arc>>, + ) -> Pin> + Send>> { + let wrapped = unfold((stream, last_received_height), |(mut s, tracker)| async move { + match s.message().await { + Ok(Some(msg)) => { + if let Some(ref blob) = msg.response { + if let Some(height) = blob.blob_type.as_ref().and_then(|b| match b { + movement_da_sequencer_proto::blob_response::BlobType::Blockv1( + inner, + ) => Some(inner.height), + _ => None, + }) { + *tracker.lock().unwrap() = Some(height); + } + } + Some((Ok(msg), (s, tracker))) + } + Ok(None) => None, + Err(e) => Some((Err(e), (s, tracker))), + } + }); - Ok(client) + Box::pin(wrapped) + } + + /// Sends a batch write request with reconnect on failure. + pub async fn batch_write( + &mut self, + request: BatchWriteRequest, + ) -> Result { + match self.client.batch_write(request.clone()).await { + Ok(response) => Ok(response.into_inner()), + Err(_) => { + self.reconnect() + .await + .map_err(|e| Status::unavailable(format!("Reconnect failed: {}", e)))?; + + let response = self.client.batch_write(request).await?; + Ok(response.into_inner()) + } + } } } diff --git a/protocol-units/da-sequencer/client/tests/connection.rs b/protocol-units/da-sequencer/client/tests/connection.rs new file mode 100644 index 000000000..6ca31298d --- /dev/null +++ b/protocol-units/da-sequencer/client/tests/connection.rs @@ -0,0 +1,224 @@ +use futures_util::StreamExt; +use movement_da_sequencer_client::DaSequencerClient; +use movement_da_sequencer_proto::blob_response::BlobType; +use movement_da_sequencer_proto::da_sequencer_node_service_server::{ + DaSequencerNodeService, DaSequencerNodeServiceServer, +}; +use movement_da_sequencer_proto::{ + BatchWriteRequest, BatchWriteResponse, BlobResponse, Blockv1, ReadAtHeightRequest, + ReadAtHeightResponse, StreamReadFromHeightRequest, StreamReadFromHeightResponse, +}; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; +use tonic::transport::Server; +use tonic::{Request, Response, Status}; + +struct MockService; + +#[tonic::async_trait] +impl DaSequencerNodeService for MockService { + type StreamReadFromHeightStream = ReceiverStream>; + + async fn stream_read_from_height( + &self, + _request: Request, + ) -> Result, Status> { + let (tx, rx) = mpsc::channel(1); + + let blob = BlobResponse { + blob_type: Some(BlobType::Blockv1(Blockv1 { + blobckid: vec![], + data: vec![], + height: 0, + })), + }; + + let _ = tx.send(Ok(StreamReadFromHeightResponse { response: Some(blob) })).await; + + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn batch_write( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(BatchWriteResponse { answer: true })) + } + + async fn read_at_height( + &self, + _request: Request, + ) -> Result, Status> { + let blob = BlobResponse { + blob_type: Some(BlobType::Blockv1(Blockv1 { + blobckid: vec![], + data: vec![], + height: 0, + })), + }; + + Ok(Response::new(ReadAtHeightResponse { response: Some(blob) })) + } +} + +#[tokio::test] +async fn test_client_reconnect_if_connection_fails() { + // Bind to an available port but do not start the server yet + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let url = format!("http://{}", addr); + + // Begin trying to connect to the DA server before it's running + let client_task = tokio::spawn(async move { DaSequencerClient::try_connect(&url).await }); + + // Simulate the server being offline briefly + tokio::time::sleep(Duration::from_secs(2)).await; + + // Now start the server + tokio::spawn(async move { + let service = DaSequencerNodeServiceServer::new(MockService); + Server::builder() + .add_service(service) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + // The client should eventually succeed + let result = client_task.await.unwrap(); + assert!(result.is_ok(), "Expected client to reconnect after retries, but it failed"); +} + +#[tokio::test] +async fn test_stream_reconnects_and_resumes_from_correct_height() { + use std::sync::{Arc, Mutex}; + use tokio::sync::{mpsc, oneshot}; + + // Shared state for sending blocks across server restarts + let _blocks_sent = Arc::new(Mutex::new(vec![ + 0, 1, // first server sends blocks 0 and 1 + 2, 3, // second server sends blocks 2 and 3 + ])); + + // Mock service that streams blocks based on the current `blocks_sent` + struct ReconnectableMock { + heights: Arc>>, + } + + #[tonic::async_trait] + impl DaSequencerNodeService for ReconnectableMock { + type StreamReadFromHeightStream = + ReceiverStream>; + + async fn stream_read_from_height( + &self, + request: Request, + ) -> Result, Status> { + let start_height = request.into_inner().height; + let (tx, rx) = mpsc::channel(10); + + let heights = self.heights.lock().unwrap().clone(); + tokio::spawn(async move { + for h in heights.into_iter().filter(|h| *h >= start_height) { + let blob = BlobResponse { + blob_type: Some(BlobType::Blockv1(Blockv1 { + blobckid: vec![], + data: vec![h as u8], + height: h, + })), + }; + + let msg = StreamReadFromHeightResponse { response: Some(blob) }; + + tx.send(Ok(msg)).await.unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn batch_write( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(BatchWriteResponse { answer: true })) + } + + async fn read_at_height( + &self, + _request: Request, + ) -> Result, Status> { + unimplemented!() + } + } + + let addr = "127.0.0.1:50055".parse::().unwrap(); + let url = format!("http://{}", addr); + + // First server: send blocks 0 and 1 + let heights_1 = Arc::new(Mutex::new(vec![0, 1])); + let mock_1 = ReconnectableMock { heights: heights_1.clone() }; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + + tokio::spawn(async move { + Server::builder() + .add_service(DaSequencerNodeServiceServer::new(mock_1)) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + shutdown_rx.await.ok(); + }) + .await + .unwrap(); + }); + + // Connect client and start receiving blocks + let mut client = DaSequencerClient::try_connect(&url).await.unwrap(); + + let mut last_height = 0; + let mut stream = client + .stream_read_from_height(StreamReadFromHeightRequest { height: 0 }) + .await + .unwrap(); + + // Receive first two blocks + for _ in 0..2 { + let res = stream.next().await.unwrap().unwrap(); + last_height = match res.response.unwrap().blob_type.unwrap() { + movement_da_sequencer_proto::blob_response::BlobType::Blockv1(inner) => inner.height, + _ => panic!("unexpected blob type"), + }; + } + + // Shut down first server + let _ = shutdown_tx.send(()); + tokio::time::sleep(Duration::from_millis(500)).await; + + // Second server: send blocks 2 and 3 + let heights_2 = Arc::new(Mutex::new(vec![2, 3])); + let mock_2 = ReconnectableMock { heights: heights_2.clone() }; + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + tokio::spawn(async move { + Server::builder() + .add_service(DaSequencerNodeServiceServer::new(mock_2)) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + // Resume stream from last_height + 1 + let mut stream = client + .stream_read_from_height(StreamReadFromHeightRequest { height: last_height + 1 }) + .await + .unwrap(); + + let res = stream.next().await.unwrap().unwrap(); + let new_height = match res.response.unwrap().blob_type.unwrap() { + movement_da_sequencer_proto::blob_response::BlobType::Blockv1(inner) => inner.height, + _ => panic!("unexpected blob type"), + }; + assert_eq!(new_height, last_height + 1, "Client did not resume at last height + 1"); +}