diff --git a/linera-rpc/src/grpc/client.rs b/linera-rpc/src/grpc/client.rs index 03f25348e4b0..15a2aa1967e9 100644 --- a/linera-rpc/src/grpc/client.rs +++ b/linera-rpc/src/grpc/client.rs @@ -1,7 +1,7 @@ // Copyright (c) Zefchain Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use std::{fmt, future::Future, iter}; +use std::{fmt, future::Future, iter, sync::Arc}; use futures::{future, stream, StreamExt}; use linera_base::{ @@ -29,6 +29,7 @@ use tracing::{debug, info, instrument, warn, Level}; use super::{ api::{self, validator_node_client::ValidatorNodeClient, SubscriptionRequest}, + pool::GrpcConnectionPool, transport, GRPC_MAX_MESSAGE_SIZE, }; use crate::{ @@ -39,7 +40,7 @@ use crate::{ #[derive(Clone)] pub struct GrpcClient { address: String, - client: ValidatorNodeClient, + pool: Arc, retry_delay: Duration, max_retries: u32, } @@ -47,49 +48,66 @@ pub struct GrpcClient { impl GrpcClient { pub fn new( address: String, - channel: transport::Channel, + pool: Arc, retry_delay: Duration, max_retries: u32, - ) -> Self { - let client = ValidatorNodeClient::new(channel) - .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE) - .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE); - Self { + ) -> Result { + // Just verify we can get a channel to this address + let _ = pool.channel(address.clone())?; + Ok(Self { address, - client, + pool, retry_delay, max_retries, - } + }) } pub fn address(&self) -> &str { &self.address } + fn make_client(&self) -> Result, super::GrpcError> { + let channel = self.pool.channel(self.address.clone())?; + Ok(ValidatorNodeClient::new(channel) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE)) + } + /// Returns whether this gRPC status means the server stream should be reconnected to, or not. /// Logs a warning on unexpected status codes. - fn is_retryable(status: &Status) -> bool { + fn is_retryable_needs_reconnect(status: &Status) -> (bool, bool) { match status.code() { Code::DeadlineExceeded | Code::Aborted | Code::Unavailable | Code::Unknown => { info!("gRPC request interrupted: {}; retrying", status); - true + (true, false) } Code::Ok | Code::Cancelled | Code::ResourceExhausted => { info!("Unexpected gRPC status: {}; retrying", status); - true + (true, false) + } + Code::NotFound => (false, false), // This code is used if e.g. the validator is missing blobs. + Code::Internal => { + let error_string = status.to_string(); + if error_string.contains("GoAway") && error_string.contains("max_age") { + info!( + "gRPC connection hit max_age and got a GoAway: {}; reconnecting then retrying", + status + ); + return (true, true); + } + info!("Unexpected gRPC status: {}", status); + (false, false) } - Code::NotFound => false, // This code is used if e.g. the validator is missing blobs. Code::InvalidArgument | Code::AlreadyExists | Code::PermissionDenied | Code::FailedPrecondition | Code::OutOfRange | Code::Unimplemented - | Code::Internal | Code::DataLoss | Code::Unauthenticated => { info!("Unexpected gRPC status: {}", status); - false + (false, false) } } } @@ -109,15 +127,36 @@ impl GrpcClient { let request_inner = request.try_into().map_err(|_| NodeError::GrpcError { error: "could not convert request to proto".to_string(), })?; + + let mut reconnected = false; loop { - match f(self.client.clone(), Request::new(request_inner.clone())).await { - Err(s) if Self::is_retryable(&s) && retry_count < self.max_retries => { - let delay = self.retry_delay.saturating_mul(retry_count); - retry_count += 1; - linera_base::time::timer::sleep(delay).await; - continue; + // Create client on-demand for each attempt + let client = match self.make_client() { + Ok(client) => client, + Err(e) => { + return Err(NodeError::GrpcError { + error: format!("Failed to create client: {}", e), + }); } + }; + + match f(client, Request::new(request_inner.clone())).await { Err(s) => { + let (is_retryable, needs_reconnect) = Self::is_retryable_needs_reconnect(&s); + if is_retryable && retry_count < self.max_retries { + // If this error indicates we need a connection refresh and we haven't already tried, do it + if needs_reconnect && !reconnected { + info!("Connection error detected, invalidating channel: {}", s); + self.pool.invalidate_channel(&self.address); + reconnected = true; + } + + let delay = self.retry_delay.saturating_mul(retry_count); + retry_count += 1; + linera_base::time::timer::sleep(delay).await; + continue; + } + return Err(NodeError::GrpcError { error: format!("remote request [{handler}] failed with status: {s:?}"), }); @@ -270,32 +309,56 @@ impl ValidatorNode for GrpcClient { let subscription_request = SubscriptionRequest { chain_ids: chains.into_iter().map(|chain| chain.into()).collect(), }; - let mut client = self.client.clone(); + let pool = self.pool.clone(); + let address = self.address.clone(); // Make the first connection attempt before returning from this method. - let mut stream = Some( + let mut stream = Some({ + let mut client = self + .make_client() + .map_err(|e| NodeError::SubscriptionFailed { + status: format!("Failed to create client: {}", e), + })?; client .subscribe(subscription_request.clone()) .await .map_err(|status| NodeError::SubscriptionFailed { status: status.to_string(), })? - .into_inner(), - ); + .into_inner() + }); // A stream of `Result` that keeps calling // `client.subscribe(request)` endlessly and without delay. let endlessly_retrying_notification_stream = stream::unfold((), move |()| { - let mut client = client.clone(); + let pool = pool.clone(); + let address = address.clone(); let subscription_request = subscription_request.clone(); let mut stream = stream.take(); async move { let stream = if let Some(stream) = stream.take() { future::Either::Right(stream) } else { - match client.subscribe(subscription_request.clone()).await { - Err(err) => future::Either::Left(stream::iter(iter::once(Err(err)))), - Ok(response) => future::Either::Right(response.into_inner()), + // Create a new client for each reconnection attempt + match pool.channel(address.clone()) { + Ok(channel) => { + let mut client = ValidatorNodeClient::new(channel) + .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE) + .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE); + match client.subscribe(subscription_request.clone()).await { + Err(err) => { + future::Either::Left(stream::iter(iter::once(Err(err)))) + } + Ok(response) => future::Either::Right(response.into_inner()), + } + } + Err(e) => { + let status = tonic::Status::unavailable(format!( + "Failed to create channel: {}", + e + )); + future::Either::Left(stream::iter(iter::once(Err(status)))) + } } }; Some((stream, ())) @@ -319,7 +382,9 @@ impl ValidatorNode for GrpcClient { return future::Either::Left(future::ready(true)); }; - if !span.in_scope(|| Self::is_retryable(status)) || retry_count >= max_retries { + let (is_retryable, _) = + span.in_scope(|| Self::is_retryable_needs_reconnect(status)); + if !is_retryable || retry_count >= max_retries { return future::Either::Left(future::ready(false)); } let delay = retry_delay.saturating_mul(retry_count); diff --git a/linera-rpc/src/grpc/node_provider.rs b/linera-rpc/src/grpc/node_provider.rs index 09078d4f64d9..c820fa0074b9 100644 --- a/linera-rpc/src/grpc/node_provider.rs +++ b/linera-rpc/src/grpc/node_provider.rs @@ -1,7 +1,7 @@ // Copyright (c) Zefchain Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use std::str::FromStr as _; +use std::{str::FromStr as _, sync::Arc}; use linera_base::time::Duration; use linera_core::node::{NodeError, ValidatorNodeProvider}; @@ -15,7 +15,7 @@ use crate::{ #[derive(Clone)] pub struct GrpcNodeProvider { - pool: GrpcConnectionPool, + pool: Arc, retry_delay: Duration, max_retries: u32, } @@ -25,7 +25,7 @@ impl GrpcNodeProvider { let transport_options = transport::Options::from(&options); let retry_delay = options.retry_delay; let max_retries = options.max_retries; - let pool = GrpcConnectionPool::new(transport_options); + let pool = Arc::new(GrpcConnectionPool::new(transport_options)); Self { pool, retry_delay, @@ -44,18 +44,15 @@ impl ValidatorNodeProvider for GrpcNodeProvider { } })?; let http_address = network.http_address(); - let channel = - self.pool - .channel(http_address.clone()) - .map_err(|error| NodeError::GrpcError { - error: format!("error creating channel: {}", error), - })?; - - Ok(GrpcClient::new( + + GrpcClient::new( http_address, - channel, + self.pool.clone(), self.retry_delay, self.max_retries, - )) + ) + .map_err(|error| NodeError::GrpcError { + error: format!("error creating client: {}", error), + }) } } diff --git a/linera-rpc/src/grpc/pool.rs b/linera-rpc/src/grpc/pool.rs index 21f136e2cfc6..355d450a6a22 100644 --- a/linera-rpc/src/grpc/pool.rs +++ b/linera-rpc/src/grpc/pool.rs @@ -32,7 +32,7 @@ impl GrpcConnectionPool { /// Obtains a channel for the current address. Either clones an existing one (thereby /// reusing the connection), or creates one if needed. New channels do not create a - /// connection immediately. + /// connection immediately and will automatically reconnect when needed. pub fn channel(&self, address: String) -> Result { let pinned = self.channels.pin(); if let Some(channel) = pinned.get(&address) { @@ -41,4 +41,10 @@ impl GrpcConnectionPool { let channel = transport::create_channel(address.clone(), &self.options)?; Ok(pinned.get_or_insert(address, channel).clone()) } + + /// Removes a channel from the pool, forcing a new connection to be created on the next request. + /// This should be called when a channel is known to be broken (e.g., received GOAWAY). + pub fn invalidate_channel(&self, address: &str) { + self.channels.pin().remove(address); + } } diff --git a/linera-rpc/tests/transport.rs b/linera-rpc/tests/transport.rs index cbf9db66c405..e8d6b93f6ff1 100644 --- a/linera-rpc/tests/transport.rs +++ b/linera-rpc/tests/transport.rs @@ -11,10 +11,7 @@ wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); async fn client() { use linera_base::time::Duration; use linera_core::node::ValidatorNode as _; - use linera_rpc::grpc::{ - transport::{create_channel, Options}, - GrpcClient, - }; + use linera_rpc::grpc::{transport::Options, GrpcClient}; let retry_delay = Duration::from_millis(100); let max_retries = 5; @@ -23,9 +20,8 @@ async fn client() { connect_timeout: Some(Duration::from_millis(100)), timeout: Some(Duration::from_millis(100)), }; - let channel = create_channel(address.clone(), &options).unwrap(); - let _ = GrpcClient::new(address, channel, retry_delay, max_retries) - .get_version_info() - .await - .unwrap(); + + let pool = std::sync::Arc::new(linera_rpc::grpc::pool::GrpcConnectionPool::new(options)); + let client = GrpcClient::new(address, pool, retry_delay, max_retries).unwrap(); + let _ = client.get_version_info().await.unwrap(); }