Skip to content

Commit 7b0e259

Browse files
authored
RUST-1337 Use tokio's AsyncRead and AsyncWrite traits (#668)
This fixes a performance regression that resulted from upgrading tokio-rustls and/or rustls.
1 parent 6d89243 commit 7b0e259

File tree

12 files changed

+110
-237
lines changed

12 files changed

+110
-237
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ chrono = "0.4.7"
5959
derivative = "2.1.1"
6060
flate2 = { version = "1.0", optional = true }
6161
futures-core = "0.3.14"
62-
futures-io = "0.3.14"
6362
futures-util = { version = "0.3.14", features = ["io"] }
6463
futures-executor = "0.3.14"
6564
hex = "0.4.0"

src/bson_util/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
use std::{convert::TryFrom, io::Read, time::Duration};
1+
use std::{
2+
convert::TryFrom,
3+
io::{Read, Write},
4+
time::Duration,
5+
};
26

37
use bson::RawBsonRef;
48
use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer};
59

610
use crate::{
711
bson::{doc, Bson, Document},
812
error::{ErrorKind, Result},
9-
runtime::{SyncLittleEndianRead, SyncLittleEndianWrite},
13+
runtime::SyncLittleEndianRead,
1014
};
1115

1216
/// Coerce numeric types into an `i64` if it would be lossless to do so. If this Bson is not numeric
@@ -203,10 +207,10 @@ fn num_decimal_digits(mut n: usize) -> u64 {
203207

204208
/// Read a document's raw BSON bytes from the provided reader.
205209
pub(crate) fn read_document_bytes<R: Read>(mut reader: R) -> Result<Vec<u8>> {
206-
let length = reader.read_i32()?;
210+
let length = reader.read_i32_sync()?;
207211

208212
let mut bytes = Vec::with_capacity(length as usize);
209-
bytes.write_i32(length)?;
213+
bytes.write_all(&length.to_le_bytes())?;
210214

211215
reader.take(length as u64 - 4).read_to_end(&mut bytes)?;
212216

src/cmap/conn/wire/header.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
use futures_io::{AsyncRead, AsyncWrite};
2-
use futures_util::AsyncWriteExt;
1+
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
32

4-
use crate::{
5-
error::{ErrorKind, Result},
6-
runtime::AsyncLittleEndianRead,
7-
};
3+
use crate::error::{ErrorKind, Result};
84

95
/// The wire protocol op codes.
106
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@@ -54,11 +50,13 @@ impl Header {
5450
}
5551

5652
/// Reads bytes from `r` and deserializes them into a header.
57-
pub(crate) async fn read_from<R: AsyncRead + Unpin + Send>(reader: &mut R) -> Result<Self> {
58-
let length = reader.read_i32().await?;
59-
let request_id = reader.read_i32().await?;
60-
let response_to = reader.read_i32().await?;
61-
let op_code = OpCode::from_i32(reader.read_i32().await?)?;
53+
pub(crate) async fn read_from<R: tokio::io::AsyncRead + Unpin + Send>(
54+
reader: &mut R,
55+
) -> Result<Self> {
56+
let length = reader.read_i32_le().await?;
57+
let request_id = reader.read_i32_le().await?;
58+
let response_to = reader.read_i32_le().await?;
59+
let op_code = OpCode::from_i32(reader.read_i32_le().await?)?;
6260
Ok(Self {
6361
length,
6462
request_id,

src/cmap/conn/wire/message.rs

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
use std::io::Read;
22

33
use bitflags::bitflags;
4-
use futures_io::AsyncWrite;
5-
use futures_util::{
6-
io::{BufReader, BufWriter},
7-
AsyncReadExt,
8-
AsyncWriteExt,
9-
};
4+
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
105

116
use super::header::{Header, OpCode};
127
use crate::{
@@ -16,7 +11,7 @@ use crate::{
1611
Command,
1712
},
1813
error::{Error, ErrorKind, Result},
19-
runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead},
14+
runtime::{AsyncStream, SyncLittleEndianRead},
2015
};
2116

2217
use crate::compression::{Compressor, Decoder};
@@ -139,7 +134,7 @@ impl Message {
139134
let mut reader = buf.as_slice();
140135

141136
// Read original opcode (should be OP_MSG)
142-
let original_opcode = reader.read_i32()?;
137+
let original_opcode = reader.read_i32_sync()?;
143138
if original_opcode != OpCode::Message as i32 {
144139
return Err(ErrorKind::InvalidResponse {
145140
message: format!(
@@ -152,10 +147,10 @@ impl Message {
152147
}
153148

154149
// Read uncompressed size
155-
let uncompressed_size = reader.read_i32()?;
150+
let uncompressed_size = reader.read_i32_sync()?;
156151

157152
// Read compressor id
158-
let compressor_id: u8 = reader.read_u8()?;
153+
let compressor_id: u8 = reader.read_u8_sync()?;
159154

160155
// Get decoder
161156
let decoder = Decoder::from_u8(compressor_id)?;
@@ -188,7 +183,7 @@ impl Message {
188183
mut length_remaining: i32,
189184
header: &Header,
190185
) -> Result<Self> {
191-
let flags = MessageFlags::from_bits_truncate(reader.read_u32()?);
186+
let flags = MessageFlags::from_bits_truncate(reader.read_u32_sync()?);
192187
length_remaining -= std::mem::size_of::<u32>() as i32;
193188

194189
let mut count_reader = SyncCountReader::new(&mut reader);
@@ -203,7 +198,7 @@ impl Message {
203198
let mut checksum = None;
204199

205200
if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) {
206-
checksum = Some(reader.read_u32()?);
201+
checksum = Some(reader.read_u32_sync()?);
207202
} else if length_remaining != 0 {
208203
return Err(ErrorKind::InvalidResponse {
209204
message: format!(
@@ -251,11 +246,11 @@ impl Message {
251246
};
252247

253248
header.write_to(&mut writer).await?;
254-
writer.write_u32(self.flags.bits()).await?;
249+
writer.write_u32_le(self.flags.bits()).await?;
255250
writer.write_all(&sections_bytes).await?;
256251

257252
if let Some(checksum) = self.checksum {
258-
writer.write_u32(checksum).await?;
253+
writer.write_u32_le(checksum).await?;
259254
}
260255

261256
writer.flush().await?;
@@ -302,9 +297,9 @@ impl Message {
302297
// Write header
303298
header.write_to(&mut writer).await?;
304299
// Write original (pre-compressed) opcode (always OP_MSG)
305-
writer.write_i32(OpCode::Message as i32).await?;
300+
writer.write_i32_le(OpCode::Message as i32).await?;
306301
// Write uncompressed size
307-
writer.write_i32(uncompressed_len as i32).await?;
302+
writer.write_i32_le(uncompressed_len as i32).await?;
308303
// Write compressor id
309304
writer.write_u8(compressor_id).await?;
310305
// Write compressed message
@@ -341,15 +336,15 @@ pub(crate) enum MessageSection {
341336
impl MessageSection {
342337
/// Reads bytes from `reader` and deserializes them into a MessageSection.
343338
fn read<R: Read>(reader: &mut R) -> Result<Self> {
344-
let payload_type = reader.read_u8()?;
339+
let payload_type = reader.read_u8_sync()?;
345340

346341
if payload_type == 0 {
347342
return Ok(MessageSection::Document(bson_util::read_document_bytes(
348343
reader,
349344
)?));
350345
}
351346

352-
let size = reader.read_i32()?;
347+
let size = reader.read_i32_sync()?;
353348
let mut length_remaining = size - std::mem::size_of::<i32>() as i32;
354349

355350
let mut identifier = String::new();
@@ -397,7 +392,7 @@ impl MessageSection {
397392
// Write payload type.
398393
writer.write_u8(1).await?;
399394

400-
writer.write_i32(*size).await?;
395+
writer.write_i32_le(*size).await?;
401396
super::util::write_cstring(writer, identifier).await?;
402397

403398
for doc in documents {

src/cmap/conn/wire/util.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ use std::{
33
sync::atomic::{AtomicI32, Ordering},
44
};
55

6-
use futures_io::{self, AsyncWrite};
7-
use futures_util::AsyncWriteExt;
86
use lazy_static::lazy_static;
7+
use tokio::io::{AsyncWrite, AsyncWriteExt};
98

109
use crate::error::Result;
1110

src/runtime/async_read_ext.rs

Lines changed: 0 additions & 56 deletions
This file was deleted.

src/runtime/async_write_ext.rs

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/runtime/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
mod acknowledged_message;
2-
mod async_read_ext;
3-
mod async_write_ext;
42
mod http;
53
#[cfg(feature = "async-std-runtime")]
64
mod interval;
75
mod join_handle;
86
mod resolver;
97
mod stream;
8+
mod sync_read_ext;
109
#[cfg(feature = "openssl-tls")]
1110
mod tls_openssl;
1211
#[cfg_attr(feature = "openssl-tls", allow(unused))]
@@ -17,11 +16,10 @@ use std::{future::Future, net::SocketAddr, time::Duration};
1716

1817
pub(crate) use self::{
1918
acknowledged_message::AcknowledgedMessage,
20-
async_read_ext::{AsyncLittleEndianRead, SyncLittleEndianRead},
21-
async_write_ext::{AsyncLittleEndianWrite, SyncLittleEndianWrite},
2219
join_handle::AsyncJoinHandle,
2320
resolver::AsyncResolver,
2421
stream::AsyncStream,
22+
sync_read_ext::SyncLittleEndianRead,
2523
worker_handle::{WorkerHandle, WorkerHandleListener},
2624
};
2725
use crate::{error::Result, options::ServerAddress};

0 commit comments

Comments
 (0)