diff --git a/.evergreen/MSRV-Cargo.toml.diff b/.evergreen/MSRV-Cargo.toml.diff index d4185946f..81793c287 100644 --- a/.evergreen/MSRV-Cargo.toml.diff +++ b/.evergreen/MSRV-Cargo.toml.diff @@ -1,8 +1,10 @@ -141c141 +116a117 +> url = "=2.5.2" +144c145 < version = "1.17.0" --- > version = "=1.38.0" -150c150 +153c154 < version = "0.7.0" --- > version = "=0.7.11" diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 2ccc78ee3..b12a69ae7 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -6,7 +6,7 @@ set -o pipefail source .evergreen/env.sh source .evergreen/cargo-test.sh -FEATURE_FLAGS+=("tracing-unstable") +FEATURE_FLAGS+=("tracing-unstable" "cert-key-password") if [ "$OPENSSL" = true ]; then FEATURE_FLAGS+=("openssl-tls") diff --git a/Cargo.toml b/Cargo.toml index 30e187e1c..b7b3a48ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ sync = [] rustls-tls = ["dep:rustls", "dep:rustls-pemfile", "dep:tokio-rustls"] openssl-tls = ["dep:openssl", "dep:openssl-probe", "dep:tokio-openssl"] dns-resolver = ["dep:hickory-resolver", "dep:hickory-proto"] +cert-key-password = ["dep:pem", "dep:pkcs8"] # Enable support for MONGODB-AWS authentication. # This can only be used with the tokio-runtime feature flag. @@ -95,7 +96,9 @@ mongodb-internal-macros = { path = "macros", version = "3.1.0" } num_cpus = { version = "1.13.1", optional = true } openssl = { version = "0.10.38", optional = true } openssl-probe = { version = "0.1.5", optional = true } +pem = { version = "3.0.4", optional = true } percent-encoding = "2.0.0" +pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"], optional = true } rand = { version = "0.8.3", features = ["small_rng"] } rayon = { version = "1.5.3", optional = true } rustc_version_runtime = "0.3.0" diff --git a/src/client/options.rs b/src/client/options.rs index e9bed02f1..73374adf2 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -1047,6 +1047,10 @@ pub struct TlsOptions { /// The default value is to error on invalid hostnames. #[cfg(feature = "openssl-tls")] pub allow_invalid_hostnames: Option, + + /// If set, the key in `cert_key_file_path` must be encrypted with this password. + #[cfg(feature = "cert-key-password")] + pub tls_certificate_key_file_password: Option>, } impl TlsOptions { @@ -1064,6 +1068,8 @@ impl TlsOptions { tlscafile: Option<&'a str>, tlscertificatekeyfile: Option<&'a str>, tlsallowinvalidcertificates: Option, + #[cfg(feature = "cert-key-password")] + tlscertificatekeyfilepassword: Option<&'a str>, } let state = TlsOptionsHelper { @@ -1077,6 +1083,11 @@ impl TlsOptions { .as_ref() .map(|s| s.to_str().unwrap()), tlsallowinvalidcertificates: tls_options.allow_invalid_certificates, + #[cfg(feature = "cert-key-password")] + tlscertificatekeyfilepassword: tls_options + .tls_certificate_key_file_password + .as_deref() + .map(|b| std::str::from_utf8(b).unwrap()), }; state.serialize(serializer) } @@ -2126,6 +2137,25 @@ impl ConnectionString { )) } }, + #[cfg(feature = "cert-key-password")] + "tlscertificatekeyfilepassword" => match &mut self.tls { + Some(Tls::Disabled) => { + return Err(ErrorKind::InvalidArgument { + message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(), + } + .into()); + } + Some(Tls::Enabled(options)) => { + options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec()); + } + None => { + self.tls = Some(Tls::Enabled( + TlsOptions::builder() + .tls_certificate_key_file_password(value.as_bytes().to_vec()) + .build(), + )) + } + }, "uuidrepresentation" => match value.to_lowercase().as_str() { "csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy), "javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy), diff --git a/src/client/options/test.rs b/src/client/options/test.rs index 3d1f4da9e..502ac04ed 100644 --- a/src/client/options/test.rs +++ b/src/client/options/test.rs @@ -20,7 +20,7 @@ static SKIPPED_TESTS: Lazy> = Lazy::new(|| { "tlsInsecure is parsed correctly", // The driver does not support maxPoolSize=0 "maxPoolSize=0 does not error", - // TODO RUST-226: unskip this test + #[cfg(not(feature = "cert-key-password"))] "Valid tlsCertificateKeyFilePassword is parsed correctly", ]; diff --git a/src/runtime.rs b/src/runtime.rs index f76f9e308..e46605bb2 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,6 +8,8 @@ mod acknowledged_message; ))] mod http; mod join_handle; +#[cfg(feature = "cert-key-password")] +mod pem; #[cfg(any(feature = "in-use-encryption", test))] pub(crate) mod process; #[cfg(feature = "dns-resolver")] diff --git a/src/runtime/pem.rs b/src/runtime/pem.rs new file mode 100644 index 000000000..ef3c3109f --- /dev/null +++ b/src/runtime/pem.rs @@ -0,0 +1,30 @@ +use crate::error::{ErrorKind, Result}; + +pub(crate) fn decrypt_private_key(pem_data: &[u8], password: &[u8]) -> Result> { + let pems = pem::parse_many(pem_data).map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Could not parse pemfile: {}", error), + })?; + let mut iter = pems + .into_iter() + .filter(|pem| pem.tag() == "ENCRYPTED PRIVATE KEY"); + let encrypted_bytes = match iter.next() { + Some(pem) => pem.into_contents(), + None => { + return Err(ErrorKind::InvalidTlsConfig { + message: "No encrypted private keys found".into(), + } + .into()) + } + }; + let encrypted_key = pkcs8::EncryptedPrivateKeyInfo::try_from(encrypted_bytes.as_slice()) + .map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Invalid encrypted private key: {}", error), + })?; + let decrypted_key = + encrypted_key + .decrypt(password) + .map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Failed to decrypt private key: {}", error), + })?; + Ok(decrypted_key.as_bytes().to_vec()) +} diff --git a/src/runtime/tls_openssl.rs b/src/runtime/tls_openssl.rs index 5d570c270..cbc431aee 100644 --- a/src/runtime/tls_openssl.rs +++ b/src/runtime/tls_openssl.rs @@ -31,11 +31,7 @@ impl TlsConfig { None => true, }; - let connector = make_openssl_connector(options).map_err(|e| { - Error::from(ErrorKind::InvalidTlsConfig { - message: e.to_string(), - }) - })?; + let connector = make_openssl_connector(options)?; Ok(TlsConfig { connector, @@ -66,25 +62,50 @@ pub(super) async fn tls_connect( Ok(stream) } -fn make_openssl_connector(cfg: TlsOptions) -> std::result::Result { - let mut builder = SslConnector::builder(SslMethod::tls_client())?; +fn make_openssl_connector(cfg: TlsOptions) -> Result { + let openssl_err = |e: ErrorStack| { + Error::from(ErrorKind::InvalidTlsConfig { + message: e.to_string(), + }) + }; + + let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?; let TlsOptions { allow_invalid_certificates, ca_file_path, cert_key_file_path, allow_invalid_hostnames: _, + #[cfg(feature = "cert-key-password")] + tls_certificate_key_file_password, } = cfg; if let Some(true) = allow_invalid_certificates { builder.set_verify(SslVerifyMode::NONE); } if let Some(path) = ca_file_path { - builder.set_ca_file(path)?; + builder.set_ca_file(path).map_err(openssl_err)?; } if let Some(path) = cert_key_file_path { - builder.set_certificate_file(path.clone(), SslFiletype::PEM)?; - builder.set_private_key_file(path, SslFiletype::PEM)?; + builder + .set_certificate_file(path.clone(), SslFiletype::PEM) + .map_err(openssl_err)?; + // Inner fn so the cert-key-password path can early return + let handle_private_key = || -> Result<()> { + #[cfg(feature = "cert-key-password")] + if let Some(key_pw) = tls_certificate_key_file_password { + let contents = std::fs::read(&path)?; + let key_bytes = super::pem::decrypt_private_key(&contents, &key_pw)?; + let key = + openssl::pkey::PKey::private_key_from_der(&key_bytes).map_err(openssl_err)?; + builder.set_private_key(&key).map_err(openssl_err)?; + return Ok(()); + } + builder + .set_private_key_file(path, SslFiletype::PEM) + .map_err(openssl_err) + }; + handle_private_key()?; } Ok(builder.build()) diff --git a/src/runtime/tls_rustls.rs b/src/runtime/tls_rustls.rs index 6dfdecd05..c60b2af7c 100644 --- a/src/runtime/tls_rustls.rs +++ b/src/runtime/tls_rustls.rs @@ -104,6 +104,13 @@ fn make_rustls_config(cfg: TlsOptions) -> Result { file.rewind()?; let key = loop { + #[cfg(feature = "cert-key-password")] + if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() { + use std::io::Read; + let mut contents = vec![]; + file.read_to_end(&mut contents)?; + break rustls::PrivateKey(super::pem::decrypt_private_key(&contents, key_pw)?); + } match read_one(&mut file) { Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => { break rustls::PrivateKey(bytes)