Skip to content

Commit 5fb9b52

Browse files
authored
RUST-360 Guard connection establishment by connectTimeoutMS (#743)
This fixes a regression that was introduced in #721.
1 parent f4a3806 commit 5fb9b52

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

src/cmap/establish/mod.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
pub(super) mod handshake;
22

3+
use std::time::Duration;
4+
35
use self::handshake::{Handshaker, HandshakerOptions};
46
use super::{
57
conn::{ConnectionGeneration, LoadBalancedGeneration, PendingConnection},
@@ -13,7 +15,7 @@ use crate::{
1315
},
1416
error::{Error as MongoError, ErrorKind, Result},
1517
hello::HelloReply,
16-
runtime::{AsyncStream, HttpClient, TlsConfig},
18+
runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, AsyncStream, HttpClient, TlsConfig},
1719
sdam::HandshakePhase,
1820
};
1921

@@ -26,11 +28,14 @@ pub(crate) struct ConnectionEstablisher {
2628

2729
/// Cached configuration needed to create TLS connections, if needed.
2830
tls_config: Option<TlsConfig>,
31+
32+
connect_timeout: Duration,
2933
}
3034

3135
pub(crate) struct EstablisherOptions {
3236
handshake_options: HandshakerOptions,
3337
tls_options: Option<TlsOptions>,
38+
connect_timeout: Option<Duration>,
3439
}
3540

3641
impl EstablisherOptions {
@@ -44,6 +49,7 @@ impl EstablisherOptions {
4449
load_balanced: opts.load_balanced.unwrap_or(false),
4550
},
4651
tls_options: opts.tls_options(),
52+
connect_timeout: opts.connect_timeout,
4753
}
4854
}
4955
}
@@ -59,14 +65,25 @@ impl ConnectionEstablisher {
5965
None
6066
};
6167

68+
let connect_timeout = match options.connect_timeout {
69+
Some(d) if d.is_zero() => Duration::MAX,
70+
Some(d) => d,
71+
None => DEFAULT_CONNECT_TIMEOUT,
72+
};
73+
6274
Ok(Self {
6375
handshaker,
6476
tls_config,
77+
connect_timeout,
6578
})
6679
}
6780

6881
async fn make_stream(&self, address: ServerAddress) -> Result<AsyncStream> {
69-
AsyncStream::connect(address, self.tls_config.as_ref()).await
82+
runtime::timeout(
83+
self.connect_timeout,
84+
AsyncStream::connect(address, self.tls_config.as_ref()),
85+
)
86+
.await?
7087
}
7188

7289
/// Establishes a connection.

src/runtime/mod.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,13 @@ pub(crate) async fn delay_for(delay: Duration) {
109109

110110
#[cfg(feature = "async-std-runtime")]
111111
{
112-
async_std::task::sleep(delay).await
112+
// This avoids a panic in async-std when the provided duration is too large.
113+
// See: https://github.com/async-rs/async-std/issues/1037.
114+
if delay == Duration::MAX {
115+
std::future::pending().await
116+
} else {
117+
async_std::task::sleep(delay).await
118+
}
113119
}
114120
}
115121

@@ -124,9 +130,15 @@ pub(crate) async fn timeout<F: Future>(timeout: Duration, future: F) -> Result<F
124130

125131
#[cfg(feature = "async-std-runtime")]
126132
{
127-
async_std::future::timeout(timeout, future)
128-
.await
129-
.map_err(|_| std::io::ErrorKind::TimedOut.into())
133+
// This avoids a panic on async-std when the provided duration is too large.
134+
// See: https://github.com/async-rs/async-std/issues/1037.
135+
if timeout == Duration::MAX {
136+
Ok(future.await)
137+
} else {
138+
async_std::future::timeout(timeout, future)
139+
.await
140+
.map_err(|_| std::io::ErrorKind::TimedOut.into())
141+
}
130142
}
131143
}
132144

src/sdam/monitor.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::{
2-
future,
32
sync::Arc,
43
time::{Duration, Instant},
54
};
@@ -185,7 +184,7 @@ impl Monitor {
185184
});
186185

187186
let heartbeat_frequency = self.heartbeat_frequency();
188-
let timeout = if self.connect_timeout().as_millis() == 0 {
187+
let timeout = if self.connect_timeout().is_zero() {
189188
// If connectTimeoutMS = 0, then the socket timeout for monitoring is unlimited.
190189
Duration::MAX
191190
} else if self.topology_version.is_some() {
@@ -198,16 +197,6 @@ impl Monitor {
198197
// Otherwise, just use connectTimeoutMS.
199198
self.connect_timeout()
200199
};
201-
let timeout_future = async {
202-
// If timeout is infinite, don't bother creating a delay future.
203-
// This also avoids a panic on async-std when the provided duration is too large.
204-
// See: https://github.com/async-rs/async-std/issues/1037.
205-
if timeout == Duration::MAX {
206-
future::pending().await
207-
} else {
208-
runtime::delay_for(timeout).await
209-
}
210-
};
211200

212201
let execute_hello = async {
213202
match self.connection {
@@ -268,7 +257,7 @@ impl Monitor {
268257
};
269258
HelloResult::Cancelled { reason: reason_error }
270259
}
271-
_ = timeout_future => {
260+
_ = runtime::delay_for(timeout) => {
272261
HelloResult::Err(Error::network_timeout())
273262
}
274263
};

0 commit comments

Comments
 (0)