Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ in-use-encryption-unstable = ["in-use-encryption"]
# The tracing API is unstable and may have backwards-incompatible changes in minor version updates.
# TODO: pending https://github.com/tokio-rs/tracing/issues/2036 stop depending directly on log.
tracing-unstable = ["dep:tracing", "dep:log"]
fuzzing = ["dep:arbitrary", "arbitrary/derive"]

[dependencies]
async-trait = "0.1.42"
base64 = "0.13.0"
bitflags = "1.1.0"
bitflags = { version = "1.1.0" }
arbitrary = { version = "1.3", optional = true }
byteorder = { version = "1.4" }
bson = { git = "https://github.com/mongodb/bson-rust", branch = "main", version = "2.11.0" }
chrono = { version = "0.4.7", default-features = false, features = [
"clock",
Expand Down
4 changes: 4 additions & 0 deletions fuzz/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target
corpus
artifacts
coverage
18 changes: 18 additions & 0 deletions fuzz/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "mongodb-fuzz"
version = "0.0.0"
publish = false
edition = "2021"

[package.metadata]
cargo-fuzz = true

[dependencies]
libfuzzer-sys = "0.4"
mongodb = { path = "..", features = ["fuzzing"] }

[[bin]]
name = "header_length"
path = "fuzz_targets/header_length.rs"
test = false
doc = false
15 changes: 15 additions & 0 deletions fuzz/fuzz_targets/header_length.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#![no_main]
use libfuzzer_sys::fuzz_target;
use mongodb::cmap::conn::wire::{header::Header, message::Message};

fuzz_target!(|data: &[u8]| {
if data.len() < Header::LENGTH {
return;
}
if let Ok(header) = Header::from_slice(data) {
let data = &data[Header::LENGTH..];
if let Ok(message) = Message::read_from_slice(data, header) {
let _ = message;
}
}
});
6 changes: 6 additions & 0 deletions src/cmap.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
#[cfg(test)]
pub(crate) mod test;

#[cfg(feature = "fuzzing")]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow missing just so we don't see warning when we run the fuzz test

#[allow(missing_docs)]
pub mod conn;

#[cfg(not(feature = "fuzzing"))]
pub(crate) mod conn;

mod connection_requester;
pub(crate) mod establish;
mod manager;
Expand Down
6 changes: 6 additions & 0 deletions src/cmap/conn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
mod command;
pub(crate) mod pooled;
mod stream_description;

#[cfg(feature = "fuzzing")]
#[allow(missing_docs)]
pub mod wire;

#[cfg(not(feature = "fuzzing"))]
pub(crate) mod wire;

use std::{sync::Arc, time::Instant};
Expand Down
26 changes: 22 additions & 4 deletions src/cmap/conn/wire.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
#[cfg(feature = "fuzzing")]
#[allow(missing_docs)]
pub mod header;
#[cfg(not(feature = "fuzzing"))]
mod header;
#[cfg(feature = "fuzzing")]
#[allow(missing_docs)]
pub mod message;
#[cfg(not(feature = "fuzzing"))]
pub(crate) mod message;
#[cfg(feature = "fuzzing")]
#[allow(missing_docs)]
pub mod util;
#[cfg(not(feature = "fuzzing"))]
mod util;

pub(crate) use self::{
message::{Message, MessageFlags},
util::next_request_id,
};
pub(crate) use self::util::next_request_id;

#[cfg(feature = "fuzzing")]
pub use self::message::Message;

#[cfg(feature = "fuzzing")]
pub use crate::fuzz::message_flags::MessageFlags;

#[cfg(not(feature = "fuzzing"))]
pub use message::{Message, MessageFlags};
73 changes: 66 additions & 7 deletions src/cmap/conn/wire/header.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use crate::error::{ErrorKind, Result};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::error::{ErrorKind, Result};
#[cfg(feature = "fuzzing")]
use arbitrary::Arbitrary;
#[cfg(feature = "fuzzing")]
use byteorder::{LittleEndian, ReadBytesExt};
#[cfg(feature = "fuzzing")]
use std::io::Cursor;

/// The wire protocol op codes.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum OpCode {
#[cfg_attr(feature = "fuzzing", derive(Arbitrary))]
pub enum OpCode {
Reply = 1,
Query = 2004,
Message = 2013,
Expand All @@ -13,7 +20,7 @@ pub(crate) enum OpCode {

impl OpCode {
/// Attempt to infer the op code based on the numeric value.
fn from_i32(i: i32) -> Result<Self> {
pub fn from_i32(i: i32) -> Result<Self> {
match i {
1 => Ok(OpCode::Reply),
2004 => Ok(OpCode::Query),
Expand All @@ -28,18 +35,71 @@ impl OpCode {
}

/// The header for any wire protocol message.
#[derive(Debug)]
pub(crate) struct Header {
#[derive(Debug, Clone)]
#[cfg_attr(feature = "fuzzing", derive(Arbitrary))]
pub struct Header {
pub length: i32,
pub request_id: i32,
pub response_to: i32,
pub op_code: OpCode,
}

impl Header {
#[cfg(feature = "fuzzing")]
pub const LENGTH: usize = 4 * std::mem::size_of::<i32>();

#[cfg(not(feature = "fuzzing"))]
pub(crate) const LENGTH: usize = 4 * std::mem::size_of::<i32>();

/// Serializes the Header and writes the bytes to `w`.
// generates a Header from a randomly generated slice of bytes, as long as the slice is at least
// 16 bytes long this is used for fuzzing
#[cfg(feature = "fuzzing")]
pub fn from_slice(data: &[u8]) -> Result<Self> {
if data.len() < Self::LENGTH {
return Err(ErrorKind::InvalidResponse {
message: format!(
"Header requires {} bytes but only got {}",
Self::LENGTH,
data.len()
),
}
.into());
}
let mut cursor = Cursor::new(data);

let length = ReadBytesExt::read_i32::<LittleEndian>(&mut cursor).map_err(|e| {
ErrorKind::InvalidResponse {
message: format!("Failed to read length: {}", e),
}
})?;

let request_id = ReadBytesExt::read_i32::<LittleEndian>(&mut cursor).map_err(|e| {
ErrorKind::InvalidResponse {
message: format!("Failed to read request_id: {}", e),
}
})?;

let response_to = ReadBytesExt::read_i32::<LittleEndian>(&mut cursor).map_err(|e| {
ErrorKind::InvalidResponse {
message: format!("Failed to read response_to: {}", e),
}
})?;

let op_code =
OpCode::from_i32(ReadBytesExt::read_i32::<LittleEndian>(&mut cursor).map_err(
|e| ErrorKind::InvalidResponse {
message: format!("Failed to read op_code: {}", e),
},
)?)?;

Ok(Self {
length,
request_id,
response_to,
op_code,
})
}

pub(crate) async fn write_to<W: AsyncWrite + Unpin>(&self, stream: &mut W) -> Result<()> {
stream.write_all(&self.length.to_le_bytes()).await?;
stream.write_all(&self.request_id.to_le_bytes()).await?;
Expand All @@ -51,7 +111,6 @@ impl Header {
Ok(())
}

/// Reads bytes from `r` and deserializes them into a header.
pub(crate) async fn read_from<R: tokio::io::AsyncRead + Unpin + Send>(
reader: &mut R,
) -> Result<Self> {
Expand Down
Loading