Skip to content

Commit 2b5da04

Browse files
authored
RUST-1337 Use tokio's AsyncRead and AsyncWrite traits (#669)
This fixes a performance regression that resulted from upgrading tokio-rustls and/or rustls.
1 parent ab491ae commit 2b5da04

File tree

12 files changed

+110
-237
lines changed

12 files changed

+110
-237
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ chrono = "0.4.7"
5959
derivative = "2.1.1"
6060
flate2 = { version = "1.0", optional = true }
6161
futures-core = "0.3.14"
62-
futures-io = "0.3.14"
6362
futures-util = { version = "0.3.14", features = ["io"] }
6463
futures-executor = "0.3.14"
6564
hex = "0.4.0"

src/bson_util/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
use std::{convert::TryFrom, io::Read, time::Duration};
1+
use std::{
2+
convert::TryFrom,
3+
io::{Read, Write},
4+
time::Duration,
5+
};
26

37
use bson::RawBsonRef;
48
use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer};
59

610
use crate::{
711
bson::{doc, Bson, Document},
812
error::{ErrorKind, Result},
9-
runtime::{SyncLittleEndianRead, SyncLittleEndianWrite},
13+
runtime::SyncLittleEndianRead,
1014
};
1115

1216
/// Coerce numeric types into an `i64` if it would be lossless to do so. If this Bson is not numeric
@@ -203,10 +207,10 @@ fn num_decimal_digits(mut n: usize) -> u64 {
203207

204208
/// Read a document's raw BSON bytes from the provided reader.
205209
pub(crate) fn read_document_bytes<R: Read>(mut reader: R) -> Result<Vec<u8>> {
206-
let length = reader.read_i32()?;
210+
let length = reader.read_i32_sync()?;
207211

208212
let mut bytes = Vec::with_capacity(length as usize);
209-
bytes.write_i32(length)?;
213+
bytes.write_all(&length.to_le_bytes())?;
210214

211215
reader.take(length as u64 - 4).read_to_end(&mut bytes)?;
212216

src/cmap/conn/wire/header.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
use futures_io::{AsyncRead, AsyncWrite};
2-
use futures_util::AsyncWriteExt;
1+
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
32

4-
use crate::{
5-
error::{ErrorKind, Result},
6-
runtime::AsyncLittleEndianRead,
7-
};
3+
use crate::error::{ErrorKind, Result};
84

95
/// The wire protocol op codes.
106
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@@ -54,11 +50,13 @@ impl Header {
5450
}
5551

5652
/// Reads bytes from `r` and deserializes them into a header.
57-
pub(crate) async fn read_from<R: AsyncRead + Unpin + Send>(reader: &mut R) -> Result<Self> {
58-
let length = reader.read_i32().await?;
59-
let request_id = reader.read_i32().await?;
60-
let response_to = reader.read_i32().await?;
61-
let op_code = OpCode::from_i32(reader.read_i32().await?)?;
53+
pub(crate) async fn read_from<R: tokio::io::AsyncRead + Unpin + Send>(
54+
reader: &mut R,
55+
) -> Result<Self> {
56+
let length = reader.read_i32_le().await?;
57+
let request_id = reader.read_i32_le().await?;
58+
let response_to = reader.read_i32_le().await?;
59+
let op_code = OpCode::from_i32(reader.read_i32_le().await?)?;
6260
Ok(Self {
6361
length,
6462
request_id,

src/cmap/conn/wire/message.rs

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
use std::io::Read;
22

33
use bitflags::bitflags;
4-
use futures_io::AsyncWrite;
5-
use futures_util::{
6-
io::{BufReader, BufWriter},
7-
AsyncReadExt,
8-
AsyncWriteExt,
9-
};
4+
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
105

116
use super::header::{Header, OpCode};
127
use crate::{
@@ -16,7 +11,7 @@ use crate::{
1611
Command,
1712
},
1813
error::{Error, ErrorKind, Result},
19-
runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead},
14+
runtime::{AsyncStream, SyncLittleEndianRead},
2015
};
2116

2217
use crate::compression::{Compressor, Decoder};
@@ -129,7 +124,7 @@ impl Message {
129124
let mut reader = buf.as_slice();
130125

131126
// Read original opcode (should be OP_MSG)
132-
let original_opcode = reader.read_i32()?;
127+
let original_opcode = reader.read_i32_sync()?;
133128
if original_opcode != OpCode::Message as i32 {
134129
return Err(ErrorKind::InvalidResponse {
135130
message: format!(
@@ -142,10 +137,10 @@ impl Message {
142137
}
143138

144139
// Read uncompressed size
145-
let uncompressed_size = reader.read_i32()?;
140+
let uncompressed_size = reader.read_i32_sync()?;
146141

147142
// Read compressor id
148-
let compressor_id: u8 = reader.read_u8()?;
143+
let compressor_id: u8 = reader.read_u8_sync()?;
149144

150145
// Get decoder
151146
let decoder = Decoder::from_u8(compressor_id)?;
@@ -178,7 +173,7 @@ impl Message {
178173
mut length_remaining: i32,
179174
header: &Header,
180175
) -> Result<Self> {
181-
let flags = MessageFlags::from_bits_truncate(reader.read_u32()?);
176+
let flags = MessageFlags::from_bits_truncate(reader.read_u32_sync()?);
182177
length_remaining -= std::mem::size_of::<u32>() as i32;
183178

184179
let mut count_reader = SyncCountReader::new(&mut reader);
@@ -193,7 +188,7 @@ impl Message {
193188
let mut checksum = None;
194189

195190
if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) {
196-
checksum = Some(reader.read_u32()?);
191+
checksum = Some(reader.read_u32_sync()?);
197192
} else if length_remaining != 0 {
198193
return Err(ErrorKind::InvalidResponse {
199194
message: format!(
@@ -241,11 +236,11 @@ impl Message {
241236
};
242237

243238
header.write_to(&mut writer).await?;
244-
writer.write_u32(self.flags.bits()).await?;
239+
writer.write_u32_le(self.flags.bits()).await?;
245240
writer.write_all(&sections_bytes).await?;
246241

247242
if let Some(checksum) = self.checksum {
248-
writer.write_u32(checksum).await?;
243+
writer.write_u32_le(checksum).await?;
249244
}
250245

251246
writer.flush().await?;
@@ -292,9 +287,9 @@ impl Message {
292287
// Write header
293288
header.write_to(&mut writer).await?;
294289
// Write original (pre-compressed) opcode (always OP_MSG)
295-
writer.write_i32(OpCode::Message as i32).await?;
290+
writer.write_i32_le(OpCode::Message as i32).await?;
296291
// Write uncompressed size
297-
writer.write_i32(uncompressed_len as i32).await?;
292+
writer.write_i32_le(uncompressed_len as i32).await?;
298293
// Write compressor id
299294
writer.write_u8(compressor_id).await?;
300295
// Write compressed message
@@ -329,15 +324,15 @@ pub(crate) enum MessageSection {
329324
impl MessageSection {
330325
/// Reads bytes from `reader` and deserializes them into a MessageSection.
331326
fn read<R: Read>(reader: &mut R) -> Result<Self> {
332-
let payload_type = reader.read_u8()?;
327+
let payload_type = reader.read_u8_sync()?;
333328

334329
if payload_type == 0 {
335330
return Ok(MessageSection::Document(bson_util::read_document_bytes(
336331
reader,
337332
)?));
338333
}
339334

340-
let size = reader.read_i32()?;
335+
let size = reader.read_i32_sync()?;
341336
let mut length_remaining = size - std::mem::size_of::<i32>() as i32;
342337

343338
let mut identifier = String::new();
@@ -385,7 +380,7 @@ impl MessageSection {
385380
// Write payload type.
386381
writer.write_u8(1).await?;
387382

388-
writer.write_i32(*size).await?;
383+
writer.write_i32_le(*size).await?;
389384
super::util::write_cstring(writer, identifier).await?;
390385

391386
for doc in documents {

src/cmap/conn/wire/util.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ use std::{
33
sync::atomic::{AtomicI32, Ordering},
44
};
55

6-
use futures_io::{self, AsyncWrite};
7-
use futures_util::AsyncWriteExt;
86
use lazy_static::lazy_static;
7+
use tokio::io::{AsyncWrite, AsyncWriteExt};
98

109
use crate::error::Result;
1110

src/runtime/async_read_ext.rs

Lines changed: 0 additions & 56 deletions
This file was deleted.

src/runtime/async_write_ext.rs

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/runtime/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
mod acknowledged_message;
2-
mod async_read_ext;
3-
mod async_write_ext;
42
mod http;
53
#[cfg(feature = "async-std-runtime")]
64
mod interval;
75
mod join_handle;
86
mod resolver;
97
mod stream;
8+
mod sync_read_ext;
109
#[cfg(feature = "openssl-tls")]
1110
mod tls_openssl;
1211
#[cfg_attr(feature = "openssl-tls", allow(unused))]
@@ -16,11 +15,10 @@ use std::{future::Future, net::SocketAddr, time::Duration};
1615

1716
pub(crate) use self::{
1817
acknowledged_message::AcknowledgedMessage,
19-
async_read_ext::{AsyncLittleEndianRead, SyncLittleEndianRead},
20-
async_write_ext::{AsyncLittleEndianWrite, SyncLittleEndianWrite},
2118
join_handle::AsyncJoinHandle,
2219
resolver::AsyncResolver,
2320
stream::AsyncStream,
21+
sync_read_ext::SyncLittleEndianRead,
2422
};
2523
use crate::{error::Result, options::ServerAddress};
2624
pub(crate) use http::HttpClient;

0 commit comments

Comments
 (0)