diff --git a/Cargo.toml b/Cargo.toml index 78f7fa08a..f5b22d81f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -114,6 +114,7 @@ tracing = { version = "0.1.36", optional = true } typed-builder = "0.10.0" webpki-roots = "0.25.2" zstd = { version = "0.11.2", optional = true } +pin-project = "1.1.7" [dependencies.pbkdf2] version = "0.11.0" diff --git a/src/action/client_options.rs b/src/action/client_options.rs index 4c09fd1fb..fe536cd49 100644 --- a/src/action/client_options.rs +++ b/src/action/client_options.rs @@ -44,7 +44,7 @@ impl ClientOptions { /// * `retryWrites`: not yet implemented /// * `retryReads`: maps to the `retry_reads` field /// * `serverSelectionTimeoutMS`: maps to the `server_selection_timeout` field - /// * `socketTimeoutMS`: unsupported, does not map to any field + /// * `socketTimeoutMS`: maps to `socket_timeout` field /// * `ssl`: an alias of the `tls` option /// * `tls`: maps to the TLS variant of the `tls` field`. /// * `tlsInsecure`: relaxes the TLS constraints on connections being made; currently is just diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 67cb296ec..84390f861 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -189,9 +189,12 @@ impl CryptExecutor { .and_then(|tls| tls.get(&provider)) .cloned() .unwrap_or_default(); - let mut stream = - AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?)) - .await?; + let mut stream = AsyncStream::connect( + addr, + Some(&TlsConfig::new(tls_options)?), + None, + ) + .await?; stream.write_all(kms_ctx.message()?).await?; let mut buf = vec![0]; while kms_ctx.bytes_needed() > 0 { diff --git a/src/client/options.rs b/src/client/options.rs index e9bed02f1..a8b2b2cac 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -567,8 +567,6 @@ pub struct ClientOptions { /// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling pub srv_service_name: Option, - #[builder(setter(skip))] - #[derive_where(skip(Debug))] pub(crate) socket_timeout: Option, /// The TLS configuration for the Client to use in its connections with the server. diff --git a/src/cmap/conn.rs b/src/cmap/conn.rs index 630c415b3..43ac8b5c2 100644 --- a/src/cmap/conn.rs +++ b/src/cmap/conn.rs @@ -128,7 +128,7 @@ impl Connection { pub(crate) fn take(&mut self) -> Self { Self { - stream: std::mem::replace(&mut self.stream, BufStream::new(AsyncStream::Null)), + stream: std::mem::replace(&mut self.stream, BufStream::new(AsyncStream::null())), stream_description: self.stream_description.take(), address: self.address.clone(), id: self.id, diff --git a/src/cmap/establish.rs b/src/cmap/establish.rs index ed44160ed..46700e9c0 100644 --- a/src/cmap/establish.rs +++ b/src/cmap/establish.rs @@ -36,6 +36,8 @@ pub(crate) struct ConnectionEstablisher { connect_timeout: Duration, + socket_timeout: Option, + #[cfg(test)] test_patch_reply: Option)>, } @@ -44,6 +46,7 @@ pub(crate) struct EstablisherOptions { handshake_options: HandshakerOptions, tls_options: Option, connect_timeout: Option, + socket_timeout: Option, #[cfg(test)] pub(crate) test_patch_reply: Option)>, } @@ -65,6 +68,7 @@ impl EstablisherOptions { }, tls_options: opts.tls_options(), connect_timeout: opts.connect_timeout, + socket_timeout: opts.socket_timeout, #[cfg(test)] test_patch_reply: None, } @@ -87,11 +91,13 @@ impl ConnectionEstablisher { Some(d) => d, None => DEFAULT_CONNECT_TIMEOUT, }; + let socket_timeout = options.socket_timeout; Ok(Self { handshaker, tls_config, connect_timeout, + socket_timeout, #[cfg(test)] test_patch_reply: options.test_patch_reply, }) @@ -100,7 +106,7 @@ impl ConnectionEstablisher { async fn make_stream(&self, address: ServerAddress) -> Result { runtime::timeout( self.connect_timeout, - AsyncStream::connect(address, self.tls_config.as_ref()), + AsyncStream::connect(address, self.tls_config.as_ref(), self.socket_timeout), ) .await? } diff --git a/src/runtime/stream.rs b/src/runtime/stream.rs index 8d9ef9762..8ef74da62 100644 --- a/src/runtime/stream.rs +++ b/src/runtime/stream.rs @@ -1,12 +1,15 @@ use std::{ net::SocketAddr, - ops::DerefMut, pin::Pin, task::{Context, Poll}, time::Duration, }; -use tokio::{io::AsyncWrite, net::TcpStream}; +use pin_project::pin_project; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; use crate::{ error::{Error, ErrorKind, Result}, @@ -24,26 +27,45 @@ pub(crate) const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); const KEEPALIVE_TIME: Duration = Duration::from_secs(120); /// An async stream possibly using TLS. +#[derive(Debug)] +#[pin_project] +pub(crate) struct AsyncStream { + #[pin] + kind: AsyncStreamKind, +} + #[allow(clippy::large_enum_variant)] #[derive(Debug)] -pub(crate) enum AsyncStream { +#[pin_project(project = AsyncStreamProj)] +enum AsyncStreamKind { Null, /// A basic TCP connection to the server. - Tcp(TcpStream), + Tcp(#[pin] TcpStream), /// A TLS connection over TCP. - Tls(TlsStream), + Tls(#[pin] TlsStream), /// A Unix domain socket connection. #[cfg(unix)] - Unix(tokio::net::UnixStream), + Unix(#[pin] tokio::net::UnixStream), +} + +impl From for AsyncStream { + fn from(kind: AsyncStreamKind) -> Self { + AsyncStream { kind } + } } impl AsyncStream { + pub(crate) fn null() -> Self { + AsyncStreamKind::Null.into() + } + pub(crate) async fn connect( address: ServerAddress, tls_cfg: Option<&TlsConfig>, + socket_timeout: Option, ) -> Result { match &address { ServerAddress::Tcp { host, .. } => { @@ -54,23 +76,29 @@ impl AsyncStream { } .into()); } - let inner = tcp_connect(resolved).await?; + let inner = tcp_connect(resolved, socket_timeout).await?; // If there are TLS options, wrap the inner stream in an AsyncTlsStream. match tls_cfg { - Some(cfg) => Ok(AsyncStream::Tls(tls_connect(host, inner, cfg).await?)), - None => Ok(AsyncStream::Tcp(inner)), + Some(cfg) => { + Ok(AsyncStreamKind::Tls(tls_connect(host, inner, cfg).await?).into()) + } + None => Ok(AsyncStreamKind::Tcp(inner).into()), } } #[cfg(unix)] - ServerAddress::Unix { path } => Ok(AsyncStream::Unix( + ServerAddress::Unix { path } => Ok(AsyncStreamKind::Unix( tokio::net::UnixStream::connect(path.as_path()).await?, - )), + ) + .into()), } } } -async fn tcp_try_connect(address: &SocketAddr) -> Result { +async fn tcp_try_connect( + address: &SocketAddr, + socket_timeout: Option, +) -> Result { let stream = TcpStream::connect(address).await?; stream.set_nodelay(true)?; @@ -80,11 +108,16 @@ async fn tcp_try_connect(address: &SocketAddr) -> Result { let conf = socket2::TcpKeepalive::new().with_time(KEEPALIVE_TIME); socket.set_tcp_keepalive(&conf)?; } + socket.set_write_timeout(socket_timeout)?; + socket.set_read_timeout(socket_timeout)?; let std_stream = std::net::TcpStream::from(socket); Ok(TcpStream::from_std(std_stream)?) } -pub(crate) async fn tcp_connect(resolved: Vec) -> Result { +pub(crate) async fn tcp_connect( + resolved: Vec, + socket_timeout: Option, +) -> Result { // "Happy Eyeballs": try addresses in parallel, interleaving IPv6 and IPv4, preferring IPv6. // Based on the implementation in https://codeberg.org/KMK/happy-eyeballs. let (addrs_v6, addrs_v4): (Vec<_>, Vec<_>) = resolved @@ -109,7 +142,7 @@ pub(crate) async fn tcp_connect(resolved: Vec) -> Result let mut attempts = tokio::task::JoinSet::new(); let mut connect_error = None; 'spawn: for a in socket_addrs { - attempts.spawn(async move { tcp_try_connect(&a).await }); + attempts.spawn(async move { tcp_try_connect(&a, socket_timeout).await }); let sleep = tokio::time::sleep(CONNECTION_ATTEMPT_DELAY); tokio::pin!(sleep); // required for select! while !attempts.is_empty() { @@ -163,54 +196,100 @@ fn interleave(left: Vec, right: Vec) -> Vec { out } -impl tokio::io::AsyncRead for AsyncStream { +impl AsyncRead for AsyncStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + self.project().kind.poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().kind.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().kind.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().kind.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[futures_io::IoSlice<'_>], + ) -> Poll> { + self.project().kind.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.kind.is_write_vectored() + } +} + +impl AsyncRead for AsyncStreamKind { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - match self.deref_mut() { - Self::Null => Poll::Ready(Ok(())), - Self::Tcp(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), - Self::Tls(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), + match self.project() { + AsyncStreamProj::Null => Poll::Ready(Ok(())), + AsyncStreamProj::Tcp(inner) => inner.poll_read(cx, buf), + AsyncStreamProj::Tls(inner) => inner.poll_read(cx, buf), #[cfg(unix)] - Self::Unix(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), + AsyncStreamProj::Unix(inner) => inner.poll_read(cx, buf), } } } -impl AsyncWrite for AsyncStream { +impl AsyncWrite for AsyncStreamKind { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match self.deref_mut() { - Self::Null => Poll::Ready(Ok(0)), - Self::Tcp(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), - Self::Tls(ref mut inner) => Pin::new(inner).poll_write(cx, buf), + match self.project() { + AsyncStreamProj::Null => Poll::Ready(Ok(0)), + AsyncStreamProj::Tcp(inner) => inner.poll_write(cx, buf), + AsyncStreamProj::Tls(inner) => inner.poll_write(cx, buf), #[cfg(unix)] - Self::Unix(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), + AsyncStreamProj::Unix(inner) => inner.poll_write(cx, buf), } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.deref_mut() { - Self::Null => Poll::Ready(Ok(())), - Self::Tcp(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), - Self::Tls(ref mut inner) => Pin::new(inner).poll_flush(cx), + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + AsyncStreamProj::Null => Poll::Ready(Ok(())), + AsyncStreamProj::Tcp(inner) => inner.poll_flush(cx), + AsyncStreamProj::Tls(inner) => inner.poll_flush(cx), #[cfg(unix)] - Self::Unix(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), + AsyncStreamProj::Unix(inner) => inner.poll_flush(cx), } } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.deref_mut() { - Self::Null => Poll::Ready(Ok(())), - Self::Tcp(ref mut inner) => Pin::new(inner).poll_shutdown(cx), - Self::Tls(ref mut inner) => Pin::new(inner).poll_shutdown(cx), + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + AsyncStreamProj::Null => Poll::Ready(Ok(())), + AsyncStreamProj::Tcp(inner) => inner.poll_shutdown(cx), + AsyncStreamProj::Tls(inner) => inner.poll_shutdown(cx), #[cfg(unix)] - Self::Unix(ref mut inner) => Pin::new(inner).poll_shutdown(cx), + AsyncStreamProj::Unix(inner) => inner.poll_shutdown(cx), } } @@ -219,22 +298,22 @@ impl AsyncWrite for AsyncStream { cx: &mut Context<'_>, bufs: &[futures_io::IoSlice<'_>], ) -> Poll> { - match self.get_mut() { - Self::Null => Poll::Ready(Ok(0)), - Self::Tcp(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), - Self::Tls(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + match self.project() { + AsyncStreamProj::Null => Poll::Ready(Ok(0)), + AsyncStreamProj::Tcp(inner) => inner.poll_write_vectored(cx, bufs), + AsyncStreamProj::Tls(inner) => inner.poll_write_vectored(cx, bufs), #[cfg(unix)] - Self::Unix(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + AsyncStreamProj::Unix(inner) => inner.poll_write_vectored(cx, bufs), } } fn is_write_vectored(&self) -> bool { match self { Self::Null => false, - Self::Tcp(ref inner) => inner.is_write_vectored(), - Self::Tls(ref inner) => inner.is_write_vectored(), + Self::Tcp(inner) => inner.is_write_vectored(), + Self::Tls(inner) => inner.is_write_vectored(), #[cfg(unix)] - Self::Unix(ref inner) => inner.is_write_vectored(), + Self::Unix(inner) => inner.is_write_vectored(), } } } diff --git a/src/test/happy_eyeballs.rs b/src/test/happy_eyeballs.rs index 112541309..1cb830016 100644 --- a/src/test/happy_eyeballs.rs +++ b/src/test/happy_eyeballs.rs @@ -8,7 +8,7 @@ const SLOW_V4: u8 = 4; const SLOW_V6: u8 = 6; async fn happy_request(payload: u8) -> (SocketAddr, SocketAddr) { - let mut control = tcp_connect(vec![CONTROL]).await.unwrap(); + let mut control = tcp_connect(vec![CONTROL], None).await.unwrap(); control.write_u8(payload).await.unwrap(); let resp = control.read_u8().await.unwrap(); assert_eq!(resp, 1); @@ -23,7 +23,7 @@ async fn happy_request(payload: u8) -> (SocketAddr, SocketAddr) { #[tokio::test] async fn slow_ipv4() { let (v4_addr, v6_addr) = happy_request(SLOW_V4).await; - let mut conn = tcp_connect(vec![v4_addr, v6_addr]).await.unwrap(); + let mut conn = tcp_connect(vec![v4_addr, v6_addr], None).await.unwrap(); assert!(conn.peer_addr().unwrap().is_ipv6()); let data = conn.read_u8().await.unwrap(); assert_eq!(data, 6); @@ -32,7 +32,7 @@ async fn slow_ipv4() { #[tokio::test] async fn slow_ipv6() { let (v4_addr, v6_addr) = happy_request(SLOW_V6).await; - let mut conn = tcp_connect(vec![v4_addr, v6_addr]).await.unwrap(); + let mut conn = tcp_connect(vec![v4_addr, v6_addr], None).await.unwrap(); assert!(conn.peer_addr().unwrap().is_ipv4()); let data = conn.read_u8().await.unwrap(); assert_eq!(data, 4); diff --git a/src/test/spec/sdam.rs b/src/test/spec/sdam.rs index f1db4f463..33a9c87e8 100644 --- a/src/test/spec/sdam.rs +++ b/src/test/spec/sdam.rs @@ -5,6 +5,7 @@ use bson::{doc, Document}; use crate::{ event::sdam::SdamEvent, hello::LEGACY_HELLO_COMMAND_NAME, + options::ClientOptions, runtime, test::{ get_client_options, @@ -35,9 +36,9 @@ async fn run_unified() { run_unified_tests(&["server-discovery-and-monitoring", "unified"]) .skip_files(&skipped_files) .skip_tests(&[ - // The driver does not support socketTimeoutMS. + // Flaky tests "Reset server and pool after network timeout error during authentication", - "Ignore network timeout error on find", + //"Ignore network timeout error on find", ]) .await; } @@ -230,6 +231,30 @@ async fn rtt_is_updated() { .unwrap(); } +#[tokio::test(flavor = "multi_thread")] +async fn socket_timeout_ms_uri_option() { + let uri = "mongodb+srv://test1.test.build.10gen.cc/?socketTimeoutMs=1"; + let options = ClientOptions::parse(uri).await.unwrap(); + assert_eq!(options.socket_timeout.unwrap().as_millis(), 1); +} + +#[tokio::test(flavor = "multi_thread")] +async fn socket_timeout_ms_client_option() { + let mut options = get_client_options().await.clone(); + options.socket_timeout = Some(Duration::from_millis(1)); + + let client = Client::with_options(options.clone()).unwrap(); + let db = client.database("test"); + let error = db + .run_command(doc! {"ping": 1}) + .await + .expect_err("should fail with socket timeout error"); + let error_description = format!("{}", error); + for host in options.hosts.iter() { + assert!(error_description.contains(format!("{}", host).as_str())); + } +} + /* TODO RUST-1895 enable this #[tokio::test(flavor = "multi_thread")] async fn heartbeat_started_before_socket() { diff --git a/src/test/spec/transactions.rs b/src/test/spec/transactions.rs index ce67ca5d8..4cfe0a9a3 100644 --- a/src/test/spec/transactions.rs +++ b/src/test/spec/transactions.rs @@ -22,7 +22,7 @@ async fn run_unified_base_api() { run_unified_tests(&["transactions", "unified"]) // TODO RUST-1656: unskip these files .skip_files(&["retryable-abort-handshake.json", "retryable-commit-handshake.json"]) - // The driver doesn't support socketTimeoutMS + // Flaky test .skip_tests(&["add RetryableWriteError and UnknownTransactionCommitResult labels to connection errors"]) .await; }