Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ tracing = { version = "0.1.36", optional = true }
typed-builder = "0.10.0"
webpki-roots = "0.25.2"
zstd = { version = "0.11.2", optional = true }
pin-project = "1.1.7"

[dependencies.pbkdf2]
version = "0.11.0"
Expand Down
2 changes: 1 addition & 1 deletion src/action/client_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl ClientOptions {
/// * `retryWrites`: not yet implemented
/// * `retryReads`: maps to the `retry_reads` field
/// * `serverSelectionTimeoutMS`: maps to the `server_selection_timeout` field
/// * `socketTimeoutMS`: unsupported, does not map to any field
/// * `socketTimeoutMS`: maps to `socket_timeout` field
/// * `ssl`: an alias of the `tls` option
/// * `tls`: maps to the TLS variant of the `tls` field`.
/// * `tlsInsecure`: relaxes the TLS constraints on connections being made; currently is just
Expand Down
9 changes: 6 additions & 3 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,12 @@ impl CryptExecutor {
.and_then(|tls| tls.get(&provider))
.cloned()
.unwrap_or_default();
let mut stream =
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?))
.await?;
let mut stream = AsyncStream::connect(
addr,
Some(&TlsConfig::new(tls_options)?),
None,
)
.await?;
stream.write_all(kms_ctx.message()?).await?;
let mut buf = vec![0];
while kms_ctx.bytes_needed() > 0 {
Expand Down
2 changes: 0 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,6 @@ pub struct ClientOptions {
/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
pub srv_service_name: Option<String>,

#[builder(setter(skip))]
#[derive_where(skip(Debug))]
pub(crate) socket_timeout: Option<Duration>,

/// The TLS configuration for the Client to use in its connections with the server.
Expand Down
2 changes: 1 addition & 1 deletion src/cmap/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl Connection {

pub(crate) fn take(&mut self) -> Self {
Self {
stream: std::mem::replace(&mut self.stream, BufStream::new(AsyncStream::Null)),
stream: std::mem::replace(&mut self.stream, BufStream::new(AsyncStream::null())),
stream_description: self.stream_description.take(),
address: self.address.clone(),
id: self.id,
Expand Down
8 changes: 7 additions & 1 deletion src/cmap/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub(crate) struct ConnectionEstablisher {

connect_timeout: Duration,

socket_timeout: Option<Duration>,

#[cfg(test)]
test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}
Expand All @@ -44,6 +46,7 @@ pub(crate) struct EstablisherOptions {
handshake_options: HandshakerOptions,
tls_options: Option<TlsOptions>,
connect_timeout: Option<Duration>,
socket_timeout: Option<Duration>,
#[cfg(test)]
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}
Expand All @@ -65,6 +68,7 @@ impl EstablisherOptions {
},
tls_options: opts.tls_options(),
connect_timeout: opts.connect_timeout,
socket_timeout: opts.socket_timeout,
#[cfg(test)]
test_patch_reply: None,
}
Expand All @@ -87,11 +91,13 @@ impl ConnectionEstablisher {
Some(d) => d,
None => DEFAULT_CONNECT_TIMEOUT,
};
let socket_timeout = options.socket_timeout;

Ok(Self {
handshaker,
tls_config,
connect_timeout,
socket_timeout,
#[cfg(test)]
test_patch_reply: options.test_patch_reply,
})
Expand All @@ -100,7 +106,7 @@ impl ConnectionEstablisher {
async fn make_stream(&self, address: ServerAddress) -> Result<AsyncStream> {
runtime::timeout(
self.connect_timeout,
AsyncStream::connect(address, self.tls_config.as_ref()),
AsyncStream::connect(address, self.tls_config.as_ref(), self.socket_timeout),
)
.await?
}
Expand Down
175 changes: 127 additions & 48 deletions src/runtime/stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::{
net::SocketAddr,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use tokio::{io::AsyncWrite, net::TcpStream};
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};

use crate::{
error::{Error, ErrorKind, Result},
Expand All @@ -24,26 +27,45 @@ pub(crate) const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const KEEPALIVE_TIME: Duration = Duration::from_secs(120);

/// An async stream possibly using TLS.
#[derive(Debug)]
#[pin_project]
pub(crate) struct AsyncStream {
#[pin]
kind: AsyncStreamKind,
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum AsyncStream {
#[pin_project(project = AsyncStreamProj)]
enum AsyncStreamKind {
Null,

/// A basic TCP connection to the server.
Tcp(TcpStream),
Tcp(#[pin] TcpStream),

/// A TLS connection over TCP.
Tls(TlsStream),
Tls(#[pin] TlsStream),

/// A Unix domain socket connection.
#[cfg(unix)]
Unix(tokio::net::UnixStream),
Unix(#[pin] tokio::net::UnixStream),
}

impl From<AsyncStreamKind> for AsyncStream {
fn from(kind: AsyncStreamKind) -> Self {
AsyncStream { kind }
}
}

impl AsyncStream {
pub(crate) fn null() -> Self {
AsyncStreamKind::Null.into()
}

pub(crate) async fn connect(
address: ServerAddress,
tls_cfg: Option<&TlsConfig>,
socket_timeout: Option<Duration>,
) -> Result<Self> {
match &address {
ServerAddress::Tcp { host, .. } => {
Expand All @@ -54,23 +76,29 @@ impl AsyncStream {
}
.into());
}
let inner = tcp_connect(resolved).await?;
let inner = tcp_connect(resolved, socket_timeout).await?;

// If there are TLS options, wrap the inner stream in an AsyncTlsStream.
match tls_cfg {
Some(cfg) => Ok(AsyncStream::Tls(tls_connect(host, inner, cfg).await?)),
None => Ok(AsyncStream::Tcp(inner)),
Some(cfg) => {
Ok(AsyncStreamKind::Tls(tls_connect(host, inner, cfg).await?).into())
}
None => Ok(AsyncStreamKind::Tcp(inner).into()),
}
}
#[cfg(unix)]
ServerAddress::Unix { path } => Ok(AsyncStream::Unix(
ServerAddress::Unix { path } => Ok(AsyncStreamKind::Unix(
tokio::net::UnixStream::connect(path.as_path()).await?,
)),
)
.into()),
}
}
}

async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
async fn tcp_try_connect(
address: &SocketAddr,
socket_timeout: Option<Duration>,
) -> Result<TcpStream> {
let stream = TcpStream::connect(address).await?;
stream.set_nodelay(true)?;

Expand All @@ -80,11 +108,16 @@ async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
let conf = socket2::TcpKeepalive::new().with_time(KEEPALIVE_TIME);
socket.set_tcp_keepalive(&conf)?;
}
socket.set_write_timeout(socket_timeout)?;
socket.set_read_timeout(socket_timeout)?;
let std_stream = std::net::TcpStream::from(socket);
Ok(TcpStream::from_std(std_stream)?)
}

pub(crate) async fn tcp_connect(resolved: Vec<SocketAddr>) -> Result<TcpStream> {
pub(crate) async fn tcp_connect(
resolved: Vec<SocketAddr>,
socket_timeout: Option<Duration>,
) -> Result<TcpStream> {
// "Happy Eyeballs": try addresses in parallel, interleaving IPv6 and IPv4, preferring IPv6.
// Based on the implementation in https://codeberg.org/KMK/happy-eyeballs.
let (addrs_v6, addrs_v4): (Vec<_>, Vec<_>) = resolved
Expand All @@ -109,7 +142,7 @@ pub(crate) async fn tcp_connect(resolved: Vec<SocketAddr>) -> Result<TcpStream>
let mut attempts = tokio::task::JoinSet::new();
let mut connect_error = None;
'spawn: for a in socket_addrs {
attempts.spawn(async move { tcp_try_connect(&a).await });
attempts.spawn(async move { tcp_try_connect(&a, socket_timeout).await });
let sleep = tokio::time::sleep(CONNECTION_ATTEMPT_DELAY);
tokio::pin!(sleep); // required for select!
while !attempts.is_empty() {
Expand Down Expand Up @@ -163,54 +196,100 @@ fn interleave<T>(left: Vec<T>, right: Vec<T>) -> Vec<T> {
out
}

impl tokio::io::AsyncRead for AsyncStream {
impl AsyncRead for AsyncStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.project().kind.poll_read(cx, buf)
}
}

impl AsyncWrite for AsyncStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::result::Result<usize, std::io::Error>> {
self.project().kind.poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
self.project().kind.poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
self.project().kind.poll_shutdown(cx)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[futures_io::IoSlice<'_>],
) -> Poll<std::result::Result<usize, std::io::Error>> {
self.project().kind.poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.kind.is_write_vectored()
}
}

impl AsyncRead for AsyncStreamKind {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(())),
Self::Tcp(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
Self::Tls(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
match self.project() {
AsyncStreamProj::Null => Poll::Ready(Ok(())),
AsyncStreamProj::Tcp(inner) => inner.poll_read(cx, buf),
AsyncStreamProj::Tls(inner) => inner.poll_read(cx, buf),
#[cfg(unix)]
Self::Unix(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
AsyncStreamProj::Unix(inner) => inner.poll_read(cx, buf),
}
}
}

impl AsyncWrite for AsyncStream {
impl AsyncWrite for AsyncStreamKind {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(0)),
Self::Tcp(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
Self::Tls(ref mut inner) => Pin::new(inner).poll_write(cx, buf),
match self.project() {
AsyncStreamProj::Null => Poll::Ready(Ok(0)),
AsyncStreamProj::Tcp(inner) => inner.poll_write(cx, buf),
AsyncStreamProj::Tls(inner) => inner.poll_write(cx, buf),
#[cfg(unix)]
Self::Unix(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
AsyncStreamProj::Unix(inner) => inner.poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(())),
Self::Tcp(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
Self::Tls(ref mut inner) => Pin::new(inner).poll_flush(cx),
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.project() {
AsyncStreamProj::Null => Poll::Ready(Ok(())),
AsyncStreamProj::Tcp(inner) => inner.poll_flush(cx),
AsyncStreamProj::Tls(inner) => inner.poll_flush(cx),
#[cfg(unix)]
Self::Unix(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
AsyncStreamProj::Unix(inner) => inner.poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(())),
Self::Tcp(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
Self::Tls(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.project() {
AsyncStreamProj::Null => Poll::Ready(Ok(())),
AsyncStreamProj::Tcp(inner) => inner.poll_shutdown(cx),
AsyncStreamProj::Tls(inner) => inner.poll_shutdown(cx),
#[cfg(unix)]
Self::Unix(ref mut inner) => Pin::new(inner).poll_shutdown(cx),
AsyncStreamProj::Unix(inner) => inner.poll_shutdown(cx),
}
}

Expand All @@ -219,22 +298,22 @@ impl AsyncWrite for AsyncStream {
cx: &mut Context<'_>,
bufs: &[futures_io::IoSlice<'_>],
) -> Poll<std::result::Result<usize, std::io::Error>> {
match self.get_mut() {
Self::Null => Poll::Ready(Ok(0)),
Self::Tcp(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
Self::Tls(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
match self.project() {
AsyncStreamProj::Null => Poll::Ready(Ok(0)),
AsyncStreamProj::Tcp(inner) => inner.poll_write_vectored(cx, bufs),
AsyncStreamProj::Tls(inner) => inner.poll_write_vectored(cx, bufs),
#[cfg(unix)]
Self::Unix(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
AsyncStreamProj::Unix(inner) => inner.poll_write_vectored(cx, bufs),
}
}

fn is_write_vectored(&self) -> bool {
match self {
Self::Null => false,
Self::Tcp(ref inner) => inner.is_write_vectored(),
Self::Tls(ref inner) => inner.is_write_vectored(),
Self::Tcp(inner) => inner.is_write_vectored(),
Self::Tls(inner) => inner.is_write_vectored(),
#[cfg(unix)]
Self::Unix(ref inner) => inner.is_write_vectored(),
Self::Unix(inner) => inner.is_write_vectored(),
}
}
}
Loading