Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ mongodb-internal-macros = { path = "macros", version = "3.1.0" }
num_cpus = { version = "1.13.1", optional = true }
openssl = { version = "0.10.38", optional = true }
openssl-probe = { version = "0.1.5", optional = true }
pem = { version = "3.0.4" }
percent-encoding = "2.0.0"
pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"] }
rand = { version = "0.8.3", features = ["small_rng"] }
rayon = { version = "1.5.3", optional = true }
rustc_version_runtime = "0.3.0"
Expand Down
22 changes: 22 additions & 0 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,10 @@ pub struct TlsOptions {
/// The default value is to error on invalid hostnames.
#[cfg(feature = "openssl-tls")]
pub allow_invalid_hostnames: Option<bool>,

/// If set, the key in `cert_key_file_path` must be encrypted with this password. Only
/// supported with `rustls`.
pub tls_certificate_key_file_password: Option<Vec<u8>>,
}

impl TlsOptions {
Expand Down Expand Up @@ -2126,6 +2130,24 @@ impl ConnectionString {
))
}
},
"tlscertificatekeyfilepassword" => match &mut self.tls {
Some(Tls::Disabled) => {
return Err(ErrorKind::InvalidArgument {
message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(),
}
.into());
}
Some(Tls::Enabled(options)) => {
options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec());
}
None => {
self.tls = Some(Tls::Enabled(
TlsOptions::builder()
.tls_certificate_key_file_password(value.as_bytes().to_vec())
.build(),
))
}
},
"uuidrepresentation" => match value.to_lowercase().as_str() {
"csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy),
"javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy),
Expand Down
1 change: 1 addition & 0 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod acknowledged_message;
))]
mod http;
mod join_handle;
mod pem;
#[cfg(any(feature = "in-use-encryption", test))]
pub(crate) mod process;
#[cfg(feature = "dns-resolver")]
Expand Down
30 changes: 30 additions & 0 deletions src/runtime/pem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::error::{ErrorKind, Result};

pub(crate) fn decrypt_private_key(pem_data: &[u8], password: &[u8]) -> Result<Vec<u8>> {
let pems = pem::parse_many(&pem_data).map_err(|error| ErrorKind::InvalidTlsConfig {
message: format!("Could not parse pemfile: {}", error),
})?;
let mut iter = pems
.into_iter()
.filter(|pem| pem.tag() == "ENCRYPTED PRIVATE KEY");
let encrypted_bytes = match iter.next() {
Some(pem) => pem.into_contents(),
None => {
return Err(ErrorKind::InvalidTlsConfig {
message: "No encrypted private keys found".into(),
}
.into())
}
};
let encrypted_key = pkcs8::EncryptedPrivateKeyInfo::try_from(encrypted_bytes.as_slice())
.map_err(|error| ErrorKind::InvalidTlsConfig {
message: format!("Invalid encrypted private key: {}", error),
})?;
let decrypted_key =
encrypted_key
.decrypt(password)
.map_err(|error| ErrorKind::InvalidTlsConfig {
message: format!("Failed to decrypt private key: {}", error),
})?;
Ok(decrypted_key.as_bytes().to_vec())
}
43 changes: 33 additions & 10 deletions src/runtime/tls_openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::{
error::{Error, ErrorKind, Result},
};

use super::pem::decrypt_private_key;

pub(super) type TlsStream = SslStream<TcpStream>;

/// Configuration required to use TLS. Creating this is expensive, so its best to cache this value
Expand All @@ -26,16 +28,19 @@ impl TlsConfig {
/// Create a new `TlsConfig` from the provided options from the user.
/// This operation is expensive, so the resultant `TlsConfig` should be cached.
pub(crate) fn new(options: TlsOptions) -> Result<TlsConfig> {
if options.tls_certificate_key_file_password.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "'tlsCertificateKeyFilePassword' can't be used with 'openssl-tls'".into(),
}
.into());
}

let verify_hostname = match options.allow_invalid_hostnames {
Some(b) => !b,
None => true,
};

let connector = make_openssl_connector(options).map_err(|e| {
Error::from(ErrorKind::InvalidTlsConfig {
message: e.to_string(),
})
})?;
let connector = make_openssl_connector(options)?;

Ok(TlsConfig {
connector,
Expand Down Expand Up @@ -66,25 +71,43 @@ pub(super) async fn tls_connect(
Ok(stream)
}

fn make_openssl_connector(cfg: TlsOptions) -> std::result::Result<SslConnector, ErrorStack> {
let mut builder = SslConnector::builder(SslMethod::tls_client())?;
fn make_openssl_connector(cfg: TlsOptions) -> Result<SslConnector> {
let openssl_err = |e: ErrorStack| {
Error::from(ErrorKind::InvalidTlsConfig {
message: e.to_string(),
})
};

let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?;

let TlsOptions {
allow_invalid_certificates,
ca_file_path,
cert_key_file_path,
allow_invalid_hostnames: _,
tls_certificate_key_file_password,
} = cfg;

if let Some(true) = allow_invalid_certificates {
builder.set_verify(SslVerifyMode::NONE);
}
if let Some(path) = ca_file_path {
builder.set_ca_file(path)?;
builder.set_ca_file(path).map_err(openssl_err)?;
}
if let Some(path) = cert_key_file_path {
builder.set_certificate_file(path.clone(), SslFiletype::PEM)?;
builder.set_private_key_file(path, SslFiletype::PEM)?;
builder
.set_certificate_file(path.clone(), SslFiletype::PEM)
.map_err(openssl_err)?;
if let Some(key_pw) = tls_certificate_key_file_password {
let contents = std::fs::read(&path)?;
let key_bytes = decrypt_private_key(&contents, &key_pw)?;
let key = openssl::pkey::PKey::private_key_from_pem(&key_bytes).map_err(openssl_err)?;
builder.set_private_key(&key).map_err(openssl_err)?;
} else {
builder
.set_private_key_file(path, SslFiletype::PEM)
.map_err(openssl_err)?;
}
}

Ok(builder.build())
Expand Down
46 changes: 27 additions & 19 deletions src/runtime/tls_rustls.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
convert::TryFrom,
fs::File,
io::{BufReader, Seek},
io::{BufReader, Read, Seek},
sync::Arc,
time::SystemTime,
};
Expand All @@ -23,6 +23,8 @@ use crate::{
error::{ErrorKind, Result},
};

use super::pem::decrypt_private_key;

pub(super) type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;

/// Configuration required to use TLS. Creating this is expensive, so its best to cache this value
Expand Down Expand Up @@ -103,26 +105,32 @@ fn make_rustls_config(cfg: TlsOptions) -> Result<rustls::ClientConfig> {
};

file.rewind()?;
let key = loop {
match read_one(&mut file) {
Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => {
break rustls::PrivateKey(bytes)
}
Ok(Some(_)) => continue,
Ok(None) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!("No PEM-encoded keys in {}", path.display()),
let key = if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() {
let mut contents = vec![];
file.read_to_end(&mut contents)?;
rustls::PrivateKey(decrypt_private_key(&contents, key_pw)?)
} else {
loop {
match read_one(&mut file) {
Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => {
break rustls::PrivateKey(bytes)
}
.into())
}
Err(_) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded item from {}",
path.display()
),
Ok(Some(_)) => continue,
Ok(None) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!("No PEM-encoded keys in {}", path.display()),
}
.into())
}
Err(_) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded item from {}",
path.display()
),
}
.into())
}
.into())
}
}
};
Expand Down