Skip to content

Commit 64e00e8

Browse files
authored
RUST-633 update driver to use tokio 1.0 (#283)
1 parent 3639656 commit 64e00e8

File tree

20 files changed

+128
-254
lines changed

20 files changed

+128
-254
lines changed

.evergreen/aws-ecs-test/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ authors = ["Saghm Rossi <[email protected]>"]
55
edition = "2018"
66

77
[dependencies]
8-
tokio = "0.2.21"
8+
tokio = "1.0.2"
99

1010
[dependencies.mongodb]
1111
path = "../.."

Cargo.toml

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ exclude = [
2121

2222
[features]
2323
default = ["tokio-runtime"]
24-
tokio-runtime = ["tokio/dns", "tokio/macros", "tokio/rt-core", "tokio/tcp", "tokio/rt-threaded", "tokio/time", "reqwest", "serde_bytes"]
25-
async-std-runtime = ["async-std", "async-std/attributes"]
24+
tokio-runtime = ["tokio/macros", "tokio/net", "tokio/rt", "tokio/time", "reqwest", "serde_bytes"]
25+
async-std-runtime = ["async-std", "async-std/attributes", "async-std-resolver", "tokio-util/compat"]
2626
sync = ["async-std-runtime"]
2727

2828
[dependencies]
29-
async-trait = "0.1.24"
29+
async-trait = "0.1.42"
3030
base64 = "0.11.0"
3131
bitflags = "1.1.0"
3232
bson = "1.1.0"
@@ -49,8 +49,8 @@ stringprep = "0.1.2"
4949
strsim = "0.10.0"
5050
take_mut = "0.2.2"
5151
time = "0.1.42"
52-
trust-dns-proto = "0.19.4"
53-
trust-dns-resolver = "0.19.5"
52+
trust-dns-proto = "0.20.0"
53+
trust-dns-resolver = "0.20.0"
5454
typed-builder = "0.4.0"
5555
version_check = "0.9.1"
5656
webpki = "0.21.0"
@@ -60,18 +60,22 @@ webpki-roots = "0.18.0"
6060
version = "1.6.2"
6161
optional = true
6262

63+
[dependencies.async-std-resolver]
64+
version = "0.20.0"
65+
optional = true
66+
6367
[dependencies.pbkdf2]
6468
version = "0.3.0"
6569
default-features = false
6670

6771
[dependencies.reqwest]
68-
version = "0.10.6"
72+
version = "0.11.0"
6973
optional = true
7074
default-features = false
7175
features = ["json", "rustls-tls"]
7276

7377
[dependencies.rustls]
74-
version = "0.17.0"
78+
version = "0.19.0"
7579
features = ["dangerous_configuration"]
7680

7781
[dependencies.serde]
@@ -83,13 +87,17 @@ version = "0.11.5"
8387
optional = true
8488

8589
[dependencies.tokio]
86-
version = "~0.2.18"
90+
version = "1.0.1"
8791
features = ["io-util", "sync", "macros"]
8892

8993
[dependencies.tokio-rustls]
90-
version = "0.13.0"
94+
version = "0.22.0"
9195
features = ["dangerous_configuration"]
9296

97+
[dependencies.tokio-util]
98+
version = "0.6.1"
99+
features = ["io"]
100+
93101
[dependencies.uuid]
94102
version = "0.8.1"
95103
features = ["v4"]

src/client/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ impl Client {
270270
}
271271

272272
let mut topology_change_subscriber =
273-
self.inner.topology.subscribe_to_topology_changes().await;
273+
self.inner.topology.subscribe_to_topology_changes();
274274
self.inner.topology.request_topology_check();
275275

276276
let time_passed = start_time.to(PreciseTime::now());

src/cmap/status.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use crate::RUNTIME;
2-
31
/// Struct used to track the latest status of the pool.
42
#[derive(Clone, Debug)]
53
struct PoolStatus {
@@ -15,11 +13,7 @@ impl Default for PoolStatus {
1513

1614
/// Create a channel for publishing and receiving updates to the pool's generation.
1715
pub(super) fn channel() -> (PoolGenerationPublisher, PoolGenerationSubscriber) {
18-
let (sender, mut receiver) = tokio::sync::watch::channel(Default::default());
19-
// The first call to recv on a watch channel returns immediately with the initial value.
20-
// We use RUNTIME.block_in_place because this is not a truly blocking task, so
21-
// the runtimes don't need to shift things around to ensure scheduling continues normally.
22-
RUNTIME.block_in_place(receiver.recv());
16+
let (sender, receiver) = tokio::sync::watch::channel(Default::default());
2317
(
2418
PoolGenerationPublisher { sender },
2519
PoolGenerationSubscriber { receiver },
@@ -40,7 +34,7 @@ impl PoolGenerationPublisher {
4034
};
4135

4236
// if nobody is listening, this will return an error, which we don't mind.
43-
let _: std::result::Result<_, _> = self.sender.broadcast(new_status);
37+
let _: std::result::Result<_, _> = self.sender.send(new_status);
4438
}
4539
}
4640

@@ -62,10 +56,11 @@ impl PoolGenerationSubscriber {
6256
timeout: std::time::Duration,
6357
) -> Option<u32> {
6458
crate::RUNTIME
65-
.timeout(timeout, self.receiver.recv())
59+
.timeout(timeout, self.receiver.changed())
6660
.await
67-
.ok()
68-
.flatten()
69-
.map(|status| status.generation)
61+
.ok()?
62+
.ok()?;
63+
64+
Some(self.receiver.borrow().generation)
7065
}
7166
}

src/cmap/test/event.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
use serde::{de::Unexpected, Deserialize, Deserializer};
77

88
use crate::{event::cmap::*, options::StreamAddress, RUNTIME};
9-
use tokio::sync::broadcast::{RecvError, SendError};
9+
use tokio::sync::broadcast::error::{RecvError, SendError};
1010

1111
#[derive(Clone, Debug)]
1212
pub struct EventHandler {

src/cmap/test/integration.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ async fn concurrent_connections() {
158158
.expect("disabling fail point should succeed");
159159
}
160160

161-
#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
161+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
162162
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
163163
#[function_name::named]
164164
async fn connection_error_during_establishment() {
@@ -209,7 +209,7 @@ async fn connection_error_during_establishment() {
209209
.expect("closed event with error reason should have been seen");
210210
}
211211

212-
#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
212+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
213213
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
214214
#[function_name::named]
215215
async fn connection_error_during_operation() {

src/error.rs

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,12 @@ pub struct Error {
3535
}
3636

3737
impl Error {
38-
pub(crate) fn new(e: Arc<ErrorKind>) -> Error {
39-
Error {
40-
kind: e,
41-
labels: Vec::new(),
42-
}
43-
}
44-
4538
pub(crate) fn pool_cleared_error(address: &StreamAddress) -> Self {
4639
ErrorKind::ConnectionPoolClearedError {
47-
message: format!("Conneciton pool for {} cleared during operation execution", address)
40+
message: format!(
41+
"Connection pool for {} cleared during operation execution",
42+
address
43+
),
4844
}
4945
.into()
5046
}
@@ -68,20 +64,6 @@ impl Error {
6864
Error::authentication_error(mechanism_name, "invalid server response")
6965
}
7066

71-
/// Attempts to get the `std::io::Error` from this `Error`.
72-
/// If there are other references to the underlying `Arc`, or if the `ErrorKind` is not `Io`,
73-
/// then the original error is returned as a custom `std::io::Error`.
74-
pub(crate) fn into_io_error(self) -> std::io::Error {
75-
match Arc::try_unwrap(self.kind) {
76-
Ok(ErrorKind::Io(io_error)) => io_error,
77-
Ok(other_error_kind) => {
78-
let error: Error = other_error_kind.into();
79-
std::io::Error::new(std::io::ErrorKind::Other, Box::new(error))
80-
}
81-
Err(e) => std::io::Error::new(std::io::ErrorKind::Other, Box::new(Error::new(e))),
82-
}
83-
}
84-
8567
/// Whether this error is an "ns not found" error or not.
8668
pub(crate) fn is_ns_not_found(&self) -> bool {
8769
matches!(self.kind.as_ref(), ErrorKind::CommandError(err) if err.code == 26)
@@ -129,10 +111,13 @@ impl Error {
129111

130112
/// Whether an error originated from the server.
131113
pub(crate) fn is_server_error(&self) -> bool {
132-
matches!(self.kind.as_ref(), ErrorKind::AuthenticationError { .. }
133-
| ErrorKind::BulkWriteError(_)
134-
| ErrorKind::CommandError(_)
135-
| ErrorKind::WriteError(_))
114+
matches!(
115+
self.kind.as_ref(),
116+
ErrorKind::AuthenticationError { .. }
117+
| ErrorKind::BulkWriteError(_)
118+
| ErrorKind::CommandError(_)
119+
| ErrorKind::WriteError(_)
120+
)
136121
}
137122

138123
/// Returns the labels for this error.
@@ -153,7 +138,9 @@ impl Error {
153138

154139
/// Whether this error contains the specified label.
155140
pub fn contains_label<T: AsRef<str>>(&self, label: T) -> bool {
156-
self.labels().iter().any(|actual_label| actual_label.as_str() == label.as_ref())
141+
self.labels()
142+
.iter()
143+
.any(|actual_label| actual_label.as_str() == label.as_ref())
157144
}
158145

159146
/// Returns a copy of this Error with the specified label added.
@@ -334,7 +321,7 @@ pub enum ErrorKind {
334321
/// A timeout occurred before a Tokio task could be completed.
335322
#[cfg(feature = "tokio-runtime")]
336323
#[error(display = "{}", _0)]
337-
TokioTimeoutElapsed(#[error(source)] tokio::time::Elapsed),
324+
TokioTimeoutElapsed(#[error(source)] tokio::time::error::Elapsed),
338325

339326
#[error(display = "{}", _0)]
340327
RustlsConfig(#[error(source)] rustls::TLSError),
@@ -372,7 +359,10 @@ impl ErrorKind {
372359
}
373360

374361
pub(crate) fn is_network_error(&self) -> bool {
375-
matches!(self, ErrorKind::Io(..) | ErrorKind::ConnectionPoolClearedError { .. })
362+
matches!(
363+
self,
364+
ErrorKind::Io(..) | ErrorKind::ConnectionPoolClearedError { .. }
365+
)
376366
}
377367

378368
/// Gets the code/message tuple from this error, if applicable. In the case of write errors, the
@@ -396,13 +386,14 @@ impl ErrorKind {
396386
pub(crate) fn code_name(&self) -> Option<&str> {
397387
match self {
398388
ErrorKind::CommandError(ref cmd_err) => Some(cmd_err.code_name.as_str()),
399-
ErrorKind::WriteError(ref failure) => {
400-
match failure {
401-
WriteFailure::WriteConcernError(ref wce) => Some(wce.code_name.as_str()),
402-
WriteFailure::WriteError(ref we) => we.code_name.as_deref(),
403-
}
404-
}
405-
ErrorKind::BulkWriteError(ref bwe) => bwe.write_concern_error.as_ref().map(|wce| wce.code_name.as_str()),
389+
ErrorKind::WriteError(ref failure) => match failure {
390+
WriteFailure::WriteConcernError(ref wce) => Some(wce.code_name.as_str()),
391+
WriteFailure::WriteError(ref we) => we.code_name.as_deref(),
392+
},
393+
ErrorKind::BulkWriteError(ref bwe) => bwe
394+
.write_concern_error
395+
.as_ref()
396+
.map(|wce| wce.code_name.as_str()),
406397
_ => None,
407398
}
408399
}

src/runtime/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ impl AsyncRuntime {
118118
pub(crate) async fn delay_for(self, delay: Duration) {
119119
#[cfg(feature = "tokio-runtime")]
120120
{
121-
tokio::time::delay_for(delay).await
121+
tokio::time::sleep(delay).await
122122
}
123123

124124
#[cfg(feature = "async-std-runtime")]

0 commit comments

Comments
 (0)